Skip to content

Commit

Permalink
Merge pull request #151 from edpaget/fix-nested-queries
Browse files Browse the repository at this point in the history
Fix WhereChain operators to work on joins
  • Loading branch information
danmcclain committed Jan 13, 2015
2 parents e511807 + 9daaf79 commit 2b02b61
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 64 deletions.
2 changes: 1 addition & 1 deletion gemfiles/Gemfile.activerecord-4.2.x
Expand Up @@ -2,7 +2,7 @@ source "https://rubygems.org"

gemspec :path => '..'

gem "activerecord", "~> 4.2.0.beta2"
gem "activerecord", "~> 4.2.0"
gem "pg", "~> 0.15"

unless ENV['CI'] || RUBY_PLATFORM =~ /java/
Expand Down
141 changes: 94 additions & 47 deletions lib/postgres_ext/active_record/relation/query_methods.rb
@@ -1,70 +1,117 @@
module ActiveRecord
module QueryMethods
class WhereChain
def overlap(opts)
opts.each do |key, value|
@scope = @scope.where(arel_table[key].overlap(value))
end
@scope
def overlap(opts, *rest)
substitute_comparisons(opts, rest, Arel::Nodes::Overlap, 'overlap')
end

def contained_within(opts)
opts.each do |key, value|
@scope = @scope.where(arel_table[key].contained_within(value))
end
def contained_within(opts, *rest)
substitute_comparisons(opts, rest, Arel::Nodes::ContainedWithin, 'contained_within')
end

@scope
def contained_within_or_equals(opts, *rest)
substitute_comparisons(opts, rest, Arel::Nodes::ContainedWithinEquals, 'contained_within_or_equals')
end

def contained_within_or_equals(opts)
opts.each do |key, value|
@scope = @scope.where(arel_table[key].contained_within_or_equals(value))
def contains(opts, *rest)
build_where_chain(opts, rest) do |rel|
case rel
when Arel::Nodes::In, Arel::Nodes::Equality
column = left_column(rel) || column_from_association(rel)
equality_for_hstore(rel) if column.type == :hstore

if column.type == :hstore
Arel::Nodes::ContainsHStore.new(rel.left, rel.right)
elsif column.respond_to?(:array) && column.array
Arel::Nodes::ContainsArray.new(rel.left, rel.right)
else
Arel::Nodes::ContainsINet.new(rel.left, rel.right)
end
else
raise ArgumentError, "Invalid argument for .where.overlap(), got #{rel.class}"
end
end
end

@scope
def contains_or_equals(opts, *rest)
substitute_comparisons(opts, rest, Arel::Nodes::ContainsEquals, 'contains_or_equals')
end

def contains(opts)
opts.each do |key, value|
@scope = @scope.where(arel_table[key].contains(value))
end
def any(opts, *rest)
equality_to_function('ANY', opts, rest)
end

@scope
def all(opts, *rest)
equality_to_function('ALL', opts, rest)
end

def contains_or_equals(opts)
opts.each do |key, value|
@scope = @scope.where(arel_table[key].contains_or_equals(value))
end
private

@scope
def find_column(col, rel)
col.name == rel.left.name.to_s || col.name == rel.left.relation.name.to_s
end

def left_column(rel)
rel.left.relation.engine.columns.find { |col| find_column(col, rel) }
end

def any(opts)
equality_to_function('ANY', opts)
def column_from_association(rel)
if assoc = assoc_from_related_table(rel)
column = assoc.klass.columns.find { |col| find_column(col, rel) }
end
end

def all(opts)
equality_to_function('ALL', opts)
def equality_for_hstore(rel)
new_right_name = rel.left.name.to_s
if rel.right.respond_to?(:val)
return if rel.right.val.is_a?(Hash)
rel.right = Arel::Nodes.build_quoted({new_right_name => rel.right.val},
rel.left)
else
return if rel.right.is_a?(Hash)
rel.right = {new_right_name => rel.right }
end

rel.left.name = rel.left.relation.name.to_sym
rel.left.relation.name = rel.left.relation.engine.table_name
end

private
def assoc_from_related_table(rel)
engine = rel.left.relation.engine
engine.reflect_on_association(rel.left.relation.name.to_sym) ||
engine.reflect_on_association(rel.left.relation.name.singularize.to_sym)
end

def arel_table
@arel_table ||= @scope.engine.arel_table
def build_where_chain(opts, rest, &block)
where_value = @scope.send(:build_where, opts, rest).map(&block)
@scope.references!(PredicateBuilder.references(opts)) if Hash === opts
@scope.where_values += where_value
@scope
end

def equality_to_function(function_name, opts)
opts.each do |key, value|
any_function = Arel::Nodes::NamedFunction.new(function_name, [arel_table[key]])
predicate = Arel::Nodes::Equality.new(value, any_function)
@scope = @scope.where(predicate)
def substitute_comparisons(opts, rest, arel_node_class, method)
build_where_chain(opts, rest) do |rel|
case rel
when Arel::Nodes::In, Arel::Nodes::Equality
arel_node_class.new(rel.left, rel.right)
else
raise ArgumentError, "Invalid argument for .where.#{method}(), got #{rel.class}"
end
end
end

@scope
def equality_to_function(function_name, opts, rest)
build_where_chain(opts, rest) do |rel|
case rel
when Arel::Nodes::Equality
Arel::Nodes::Equality.new(rel.right, Arel::Nodes::NamedFunction.new(function_name, [rel.left]))
else
raise ArgumentError, "Invalid argument for .where.#{funciton_name.downcase}(), got #{rel.class}"
end
end
end
end

# WithChain objects act as placeholder for queries in which #with does not have any parameter.
# In this case, #with must be chained with #recursive to return a new relation.
class WithChain
Expand Down Expand Up @@ -173,15 +220,15 @@ def build_with(arel)
def build_rank(arel, rank_window_options)
unless arel.projections.count == 1 && Arel::Nodes::Count === arel.projections.first
rank_window = case rank_window_options
when :order
arel.orders
when Symbol
table[rank_window_options].asc
when Hash
rank_window_options.map { |field, dir| table[field].send(dir) }
else
Arel::Nodes::SqlLiteral.new "(#{rank_window_options})"
end
when :order
arel.orders
when Symbol
table[rank_window_options].asc
when Hash
rank_window_options.map { |field, dir| table[field].send(dir) }
else
Arel::Nodes::SqlLiteral.new "(#{rank_window_options})"
end

unless rank_window.blank?
rank_node = Arel::Nodes::SqlLiteral.new 'rank()'
Expand Down
1 change: 1 addition & 0 deletions lib/postgres_ext/arel/4.1/visitors.rb
@@ -1 +1,2 @@
require 'postgres_ext/arel/4.1/visitors/depth_first'
require 'postgres_ext/arel/4.1/visitors/postgresql'
9 changes: 9 additions & 0 deletions lib/postgres_ext/arel/4.1/visitors/depth_first.rb
@@ -0,0 +1,9 @@
require 'arel/visitors/depth_first'

module Arel
module Visitors
class DepthFirst
alias :visit_IPAddr :terminal
end
end
end
28 changes: 20 additions & 8 deletions lib/postgres_ext/arel/4.1/visitors/postgresql.rb
Expand Up @@ -4,7 +4,7 @@ module Arel
module Visitors
class PostgreSQL
private

def visit_Array o, a
column = a.relation.engine.connection.schema_cache.columns(a.relation.name).find { |col| col.name == a.name.to_s } if a
if column && column.respond_to?(:array) && column.array
Expand All @@ -13,6 +13,16 @@ def visit_Array o, a
o.empty? ? 'NULL' : o.map { |x| visit x }.join(', ')
end
end

def visit_Arel_Nodes_Contains o, a = nil
left_column = o.left.relation.engine.columns.find { |col| col.name == o.left.name.to_s }

if left_column && (left_column.type == :hstore || (left_column.respond_to?(:array) && left_column.array))
"#{visit o.left, a} @> #{visit o.right, o.left}"
else
"#{visit o.left, a} >> #{visit o.right, o.left}"
end
end

def visit_Arel_Nodes_ContainedWithin o, a = nil
"#{visit o.left, a} << #{visit o.right, o.left}"
Expand All @@ -22,14 +32,16 @@ def visit_Arel_Nodes_ContainedWithinEquals o, a = nil
"#{visit o.left, a} <<= #{visit o.right, o.left}"
end

def visit_Arel_Nodes_Contains o, a = nil
left_column = o.left.relation.engine.columns.find { |col| col.name == o.left.name.to_s }
def visit_Arel_Nodes_ContainsArray o, a = nil
"#{visit o.left, a} @> #{visit o.right, o.left}"
end

if left_column && (left_column.type == :hstore || (left_column.respond_to?(:array) && left_column.array))
"#{visit o.left, a} @> #{visit o.right, o.left}"
else
"#{visit o.left, a} >> #{visit o.right, o.left}"
end
def visit_Arel_Nodes_ContainsHStore o, a = nil
"#{visit o.left, a} @> #{visit o.right, o.left}"
end

def visit_Arel_Nodes_ContainsINet o, a = nil
"#{visit o.left, a} >> #{visit o.right, o.left}"
end

def visit_Arel_Nodes_ContainsEquals o, a = nil
Expand Down
20 changes: 17 additions & 3 deletions lib/postgres_ext/arel/4.2/visitors/postgresql.rb
Expand Up @@ -4,7 +4,7 @@ module Arel
module Visitors
class PostgreSQL
private

def visit_Arel_Nodes_ContainedWithin o, collector
infix_value o, collector, " << "
end
Expand All @@ -14,15 +14,29 @@ def visit_Arel_Nodes_ContainedWithinEquals o, collector
end

def visit_Arel_Nodes_Contains o, collector
left_column = o.left.relation.engine.columns.find { |col| col.name == o.left.name.to_s }
left_column = o.left.relation.engine.columns.find do |col|
col.name == o.left.name.to_s || col.name == o.left.relation.name.to_s
end

if left_column && (left_column.type == :hstore || (left_column.respond_to?(:array) && left_column.array))
infix_value o, collector, " @> "
infix_value o, collector, " @> "
else
infix_value o, collector, " >> "
end
end

def visit_Arel_Nodes_ContainsINet o, collector
infix_value o, collector, " >> "
end

def visit_Arel_Nodes_ContainsHStore o, collector
infix_value o, collector, " @> "
end

def visit_Arel_Nodes_ContainsArray o, collector
infix_value o, collector, " @> "
end

def visit_Arel_Nodes_ContainsEquals o, collector
infix_value o, collector, " >>= "
end
Expand Down
16 changes: 14 additions & 2 deletions lib/postgres_ext/arel/nodes/contained_within.rb
Expand Up @@ -6,15 +6,27 @@ def operator; :<< end
end

class ContainedWithinEquals < Arel::Nodes::Binary
def operator; '<<='.to_sym end
def operator; :"<<=" end
end

class Contains < Arel::Nodes::Binary
def operator; :>> end
end

class ContainsINet < Arel::Nodes::Binary
def operator; :>> end
end

class ContainsHStore < Arel::Nodes::Binary
def operator; :"@>" end
end

class ContainsArray < Arel::Nodes::Binary
def operator; :"@>" end
end

class ContainsEquals < Arel::Nodes::Binary
def operator; '>>='.to_sym end
def operator; :">>=" end
end
end
end
8 changes: 6 additions & 2 deletions test/queries/array_queries_test.rb
Expand Up @@ -15,9 +15,8 @@

describe '.where(joins: { array_column: [] })' do
it 'returns an array string instead of IN ()' do
skip
query = Person.joins(:hm_tags).where(tags: { categories: ['working'] }).to_sql
query.must_match equality_regex
query.must_match %r{\"tags\"\.\"categories\" = '\{"?working"?\}'}
end
end

Expand All @@ -33,6 +32,11 @@
query.must_match overlap_regex
query.must_match equality_regex
end

it 'works on joins' do
query = Person.joins(:hm_tags).where.overlap(tags: { categories: ['working'] }).to_sql
query.must_match %r{\"tags\"\.\"categories\" && '\{"?working"?\}'}
end
end


Expand Down
13 changes: 12 additions & 1 deletion test/queries/contains_test.rb
Expand Up @@ -41,16 +41,27 @@
query.to_sql.must_match contains_array_regex
end

it 'generates the appropriate where clause for array columns' do
it 'generates the appropriate where clause for hstore columns' do
query = Person.where.contains(data: { nickname: 'Dan' })
query.to_sql.must_match contains_hstore_regex
end

it 'generates the appropriate where clause for hstore columns on joins' do
query = Tag.joins(:person).where.contains(people: { data: { nickname: 'Dan' } })
query.to_sql.must_match contains_hstore_regex
end

it 'allows chaining' do
query = Person.where.contains(:tag_ids => [1,2]).where(:tags => ['working']).to_sql

query.must_match contains_array_regex
query.must_match equality_regex
end

it 'generates the appropriate where clause for array columns on joins' do
query = Tag.joins(:person).where.contains(people: { tag_ids: [1,2] }).to_sql

query.must_match contains_array_regex
end
end
end

0 comments on commit 2b02b61

Please sign in to comment.