Skip to content

Commit

Permalink
create_sig takes hash instead of TypedParams
Browse files Browse the repository at this point in the history
TypeParam accepts a default, which is confusing since it's not used in
sigs. SigParam is the right thing to use here, and create_sig will
build these from a hash.
  • Loading branch information
bdewater committed Apr 4, 2024
1 parent 1b8648a commit cdd80d1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 58 deletions.
86 changes: 33 additions & 53 deletions lib/tapioca/dsl/compilers/active_record_relations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -667,11 +667,11 @@ def create_common_methods
end
sigs = [
common_relation_methods_module.create_sig(
parameters: [create_param("args", type: id_types)],
parameters: { args: id_types },
return_type: constant_name,
),
common_relation_methods_module.create_sig(
parameters: [create_param("args", type: array_type)],
parameters: { args: array_type },
return_type: "T::Enumerable[#{constant_name}]",
),
]
Expand Down Expand Up @@ -714,11 +714,11 @@ def create_common_methods
when :first, :last, :take
sigs = [
common_relation_methods_module.create_sig(
parameters: [create_opt_param("limit", type: "NilClass", default: "nil")],
parameters: { limit: "NilClass" },
return_type: as_nilable_type(constant_name),
),
common_relation_methods_module.create_sig(
parameters: [create_param("limit", type: "Integer")],
parameters: { limit: "Integer" },
return_type: "T::Array[#{constant_name}]",
),
]
Expand Down Expand Up @@ -819,26 +819,20 @@ def create_common_methods
case method_name
when :find_each
order = ActiveRecord::Batches.instance_method(:find_each).parameters.include?([:key, :order])
parameters = {
start: "T.untyped",
finish: "T.untyped",
batch_size: "Integer",
error_on_ignore: "T.untyped",
order: ("Symbol" if order),
}.compact
sigs = [
common_relation_methods_module.create_sig(
parameters: [
create_kw_opt_param("start", type: "T.untyped", default: "nil"),
create_kw_opt_param("finish", type: "T.untyped", default: "nil"),
create_kw_opt_param("batch_size", type: "Integer", default: "1000"),
create_kw_opt_param("error_on_ignore", type: "T.untyped", default: "nil"),
*(create_kw_opt_param("order", type: "Symbol", default: ":asc") if order),
create_block_param("block", type: "T.proc.params(object: #{constant_name}).void"),
],
parameters: parameters.merge(block: "T.proc.params(object: #{constant_name}).void"),
return_type: "void",
),
common_relation_methods_module.create_sig(
parameters: [
create_kw_opt_param("start", type: "T.untyped", default: "nil"),
create_kw_opt_param("finish", type: "T.untyped", default: "nil"),
create_kw_opt_param("batch_size", type: "Integer", default: "1000"),
create_kw_opt_param("error_on_ignore", type: "T.untyped", default: "nil"),
*(create_kw_opt_param("order", type: "Symbol", default: ":asc") if order),
],
parameters: parameters,
return_type: "T::Enumerator[#{constant_name}]",
),
]
Expand All @@ -856,26 +850,20 @@ def create_common_methods
)
when :find_in_batches
order = ActiveRecord::Batches.instance_method(:find_in_batches).parameters.include?([:key, :order])
parameters = {
start: "T.untyped",
finish: "T.untyped",
batch_size: "Integer",
error_on_ignore: "T.untyped",
order: ("Symbol" if order),
}.compact
sigs = [
common_relation_methods_module.create_sig(
parameters: [
create_kw_opt_param("start", type: "T.untyped", default: "nil"),
create_kw_opt_param("finish", type: "T.untyped", default: "nil"),
create_kw_opt_param("batch_size", type: "Integer", default: "1000"),
create_kw_opt_param("error_on_ignore", type: "T.untyped", default: "nil"),
*(create_kw_opt_param("order", type: "Symbol", default: ":asc") if order),
create_block_param("block", type: "T.proc.params(object: T::Array[#{constant_name}]).void"),
],
parameters: parameters.merge(block: "T.proc.params(object: T::Array[#{constant_name}]).void"),
return_type: "void",
),
common_relation_methods_module.create_sig(
parameters: [
create_kw_opt_param("start", type: "T.untyped", default: "nil"),
create_kw_opt_param("finish", type: "T.untyped", default: "nil"),
create_kw_opt_param("batch_size", type: "Integer", default: "1000"),
create_kw_opt_param("error_on_ignore", type: "T.untyped", default: "nil"),
*(create_kw_opt_param("order", type: "Symbol", default: ":asc") if order),
],
parameters: parameters,
return_type: "T::Enumerator[T::Enumerator[#{constant_name}]]",
),
]
Expand All @@ -894,30 +882,22 @@ def create_common_methods
when :in_batches
order = ActiveRecord::Batches.instance_method(:in_batches).parameters.include?([:key, :order])
use_ranges = ActiveRecord::Batches.instance_method(:in_batches).parameters.include?([:key, :use_ranges])
parameters = {
of: "Integer",
start: "T.untyped",
finish: "T.untyped",
load: "T.untyped",
error_on_ignore: "T.untyped",
order: ("Symbol" if order),
use_ranges: ("T.untyped" if use_ranges),
}.compact
sigs = [
common_relation_methods_module.create_sig(
parameters: [
create_kw_opt_param("of", type: "Integer", default: "1000"),
create_kw_opt_param("start", type: "T.untyped", default: "nil"),
create_kw_opt_param("finish", type: "T.untyped", default: "nil"),
create_kw_opt_param("load", type: "T.untyped", default: "false"),
create_kw_opt_param("error_on_ignore", type: "T.untyped", default: "nil"),
*(create_kw_opt_param("order", type: "Symbol", default: ":asc") if order),
*(create_kw_opt_param("use_ranges", type: "T.untyped", default: "nil") if use_ranges),
create_block_param("block", type: "T.proc.params(object: #{RelationClassName}).void"),
],
parameters: parameters.merge(block: "T.proc.params(object: #{RelationClassName}).void"),
return_type: "void",
),
common_relation_methods_module.create_sig(
parameters: [
create_kw_opt_param("of", type: "Integer", default: "1000"),
create_kw_opt_param("start", type: "T.untyped", default: "nil"),
create_kw_opt_param("finish", type: "T.untyped", default: "nil"),
create_kw_opt_param("load", type: "T.untyped", default: "false"),
create_kw_opt_param("error_on_ignore", type: "T.untyped", default: "nil"),
*(create_kw_opt_param("order", type: "Symbol", default: ":asc") if order),
*(create_kw_opt_param("use_ranges", type: "T.untyped", default: "nil") if use_ranges),
],
parameters: parameters,
return_type: "::ActiveRecord::Batches::BatchEnumerator",
),
]
Expand Down
11 changes: 6 additions & 5 deletions lib/tapioca/rbi_ext/model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def create_type_variable(name, type:, variance: :invariant, fixed: nil, upper: n
end
def create_method(name, parameters: [], return_type: "T.untyped", class_method: false, visibility: RBI::Public.new,
comments: [])
sig = create_sig(parameters: parameters, return_type: return_type)
sig_params = parameters.to_h { |param| [param.param.name, param.type] }
sig = create_sig(parameters: sig_params, return_type: return_type)
create_method_with_sigs(
name,
sigs: [sig],
Expand Down Expand Up @@ -126,13 +127,13 @@ def create_method_with_sigs(name, sigs:, parameters: [], class_method: false, vi

sig do
params(
parameters: T::Array[RBI::TypedParam],
parameters: T::Hash[T.any(String, Symbol), String],
return_type: String,
).returns(RBI::Sig)
end
def create_sig(parameters: [], return_type: "T.untyped")
params = parameters.map do |param|
RBI::SigParam.new(param.param.name, param.type)
def create_sig(parameters:, return_type: "T.untyped")
params = parameters.map do |name, type|
RBI::SigParam.new(name.to_s, type)
end
RBI::Sig.new(params: params, return_type: return_type)
end
Expand Down

0 comments on commit cdd80d1

Please sign in to comment.