Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start passing in an option to ActiveRecordColumns compiler for how to generate column types #1888

Merged
merged 14 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,7 @@ dsl:
app_root: "."
halt_upon_load_error: true
skip_constant: []
compiler_options: {}
gem:
outdir: sorbet/rbi/gems
file_header: true
Expand Down
29 changes: 29 additions & 0 deletions lib/tapioca/cli.rb
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,20 @@ def todo
banner: "constant [constant ...]",
desc: "Do not generate RBI definitions for the given application constant(s)",
default: []
option :compiler_options,
type: :hash,
desc: "Options to pass to the DSL compilers",
hide: true,
default: {}
def dsl(*constant_or_paths)
set_environment(options)

# Assume anything starting with a capital letter or colon is a class, otherwise a path
constants, paths = constant_or_paths.partition { |c| c =~ /\A[A-Z:]/ }

# Make sure compiler options are received as a hash
compiler_options = process_compiler_options

command_args = {
requested_constants: constants,
requested_paths: paths.map { |p| Pathname.new(p) },
Expand All @@ -161,6 +169,7 @@ def dsl(*constant_or_paths)
rbi_formatter: rbi_formatter(options),
app_root: options[:app_root],
halt_upon_load_error: options[:halt_upon_load_error],
compiler_options: compiler_options,
}

command = if options[:verify]
Expand Down Expand Up @@ -372,6 +381,26 @@ def exit_on_failure?

private

def process_compiler_options
compiler_options = options[:compiler_options]

# Parse all compiler option hash values as YAML if they are Strings
compiler_options.transform_values! do |value|
value = YAML.safe_load(value) if String === value
value
rescue YAML::Exception
raise MalformattedArgumentError,
"Option '--compiler-options' should have well-formatted YAML strings, but received: '#{value}'"
end

unless compiler_options.values.all? { |v| Hash === v }
raise MalformattedArgumentError,
"Option '--compiler-options' should be a hash of hashes, but received: '#{compiler_options}'"
end

compiler_options
end

def print_init_next_steps
say(<<~OUTPUT)
#{set_color("This project is now set up for use with Sorbet and Tapioca", :bold)}
Expand Down
6 changes: 5 additions & 1 deletion lib/tapioca/commands/abstract_dsl.rb
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class AbstractDsl < CommandWithoutTracker
rbi_formatter: RBIFormatter,
app_root: String,
halt_upon_load_error: T::Boolean,
compiler_options: T::Hash[String, T.untyped],
).void
end
def initialize(
Expand All @@ -45,7 +46,8 @@ def initialize(
gem_dir: DEFAULT_GEM_DIR,
rbi_formatter: DEFAULT_RBI_FORMATTER,
app_root: ".",
halt_upon_load_error: true
halt_upon_load_error: true,
compiler_options: {}
)
@requested_constants = requested_constants
@requested_paths = requested_paths
Expand All @@ -63,6 +65,7 @@ def initialize(
@app_root = app_root
@halt_upon_load_error = halt_upon_load_error
@skip_constant = skip_constant
@compiler_options = compiler_options

super()
end
Expand Down Expand Up @@ -129,6 +132,7 @@ def create_pipeline
},
skipped_constants: constantize(@skip_constant),
number_of_workers: @number_of_workers,
compiler_options: @compiler_options,
)
end

Expand Down
15 changes: 13 additions & 2 deletions lib/tapioca/dsl/compiler.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class Compiler
sig { returns(RBI::Tree) }
attr_reader :root

sig { returns(T::Hash[String, T.untyped]) }
attr_reader :options

class << self
extend T::Sig

Expand Down Expand Up @@ -60,11 +63,19 @@ def all_modules
end
end

sig { params(pipeline: Tapioca::Dsl::Pipeline, root: RBI::Tree, constant: ConstantType).void }
def initialize(pipeline, root, constant)
sig do
params(
pipeline: Tapioca::Dsl::Pipeline,
root: RBI::Tree,
constant: ConstantType,
options: T::Hash[String, T.untyped],
).void
end
def initialize(pipeline, root, constant, options = {})
@pipeline = pipeline
@root = root
@constant = constant
@options = options
@errors = T.let([], T::Array[String])
end

Expand Down
15 changes: 14 additions & 1 deletion lib/tapioca/dsl/compilers/active_record_columns.rb
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,19 @@ def gather_constants

private

sig { returns(Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption) }
def column_type_option
@column_type_option ||= T.let(
Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption.from_serialized(
options.fetch(
"types",
Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption::Persisted.serialize,
),
),
T.nilable(Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption),
)
end

sig do
params(
klass: RBI::Scope,
Expand Down Expand Up @@ -174,7 +187,7 @@ def add_method(klass, name, methods_to_add, return_type: "void", parameters: [])
end
def add_methods_for_attribute(klass, attribute_name, column_name = attribute_name, methods_to_add = nil)
getter_type, setter_type = Helpers::ActiveRecordColumnTypeHelper
.new(constant)
.new(constant, column_type_option: column_type_option)
.type_for(attribute_name, column_name)

# Added by ActiveRecord::AttributeMethods::Read
Expand Down
5 changes: 4 additions & 1 deletion lib/tapioca/dsl/compilers/identity_cache.rb
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ def create_index_fetch_by_methods(field, klass)
).void
end
def create_aliased_fetch_by_methods(field, klass)
type, _ = Helpers::ActiveRecordColumnTypeHelper.new(constant).type_for(field.alias_name.to_s)
type, _ = Helpers::ActiveRecordColumnTypeHelper.new(
KaanOzkan marked this conversation as resolved.
Show resolved Hide resolved
constant,
column_type_option: Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption::Nilable,
).type_for(field.alias_name.to_s)
multi_type = type.delete_prefix("T.nilable(").delete_suffix(")").delete_prefix("::")
suffix = field.send(:fetch_method_suffix)

Expand Down
69 changes: 48 additions & 21 deletions lib/tapioca/dsl/helpers/active_record_column_type_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,47 @@ class ActiveRecordColumnTypeHelper
extend T::Sig
include RBIHelper

sig { params(constant: T.class_of(ActiveRecord::Base)).void }
def initialize(constant)
class ColumnTypeOption < T::Enum
extend T::Sig
enums do
Untyped = new("untyped")
Nilable = new("nilable")
Persisted = new("persisted")
end

sig { returns(T::Boolean) }
def persisted?
self == ColumnTypeOption::Persisted
end

sig { returns(T::Boolean) }
def nilable?
self == ColumnTypeOption::Nilable
end

sig { returns(T::Boolean) }
def untyped?
self == ColumnTypeOption::Untyped
end
end

sig do
params(
constant: T.class_of(ActiveRecord::Base),
column_type_option: ColumnTypeOption,
).void
end
def initialize(constant, column_type_option: ColumnTypeOption::Persisted)
@constant = constant
@column_type_option = column_type_option
end

sig { params(attribute_name: String, column_name: String).returns([String, String]) }
sig do
params(
attribute_name: String,
column_name: String,
).returns([String, String])
end
def type_for(attribute_name, column_name = attribute_name)
return id_type if attribute_name == "id"

Expand All @@ -27,15 +62,19 @@ def type_for(attribute_name, column_name = attribute_name)
sig { returns([String, String]) }
def id_type
if @constant.respond_to?(:composite_primary_key?) && T.unsafe(@constant).composite_primary_key?
@constant.primary_key.map(&method(:column_type_for)).map { |tuple| "[#{tuple.join(", ")}]" }
@constant.primary_key.map do |column|
column_type_for(column)
end.map do |tuple|
"[#{tuple.join(", ")}]"
end
else
column_type_for(@constant.primary_key)
end
end

sig { params(column_name: String).returns([String, String]) }
def column_type_for(column_name)
return ["T.untyped", "T.untyped"] if do_not_generate_strong_types?(@constant)
return ["T.untyped", "T.untyped"] if @column_type_option.untyped?

column = @constant.columns_hash[column_name]
column_type = @constant.attribute_types[column_name]
Expand All @@ -48,18 +87,12 @@ def column_type_for(column_name)
getter_type
end

if column&.null
if @column_type_option.persisted? && !column&.null
[getter_type, setter_type]
else
getter_type = as_nilable_type(getter_type) unless not_nilable_serialized_column?(column_type)
return [getter_type, as_nilable_type(setter_type)]
[getter_type, as_nilable_type(setter_type)]
end

if Array(@constant.primary_key).include?(column_name) ||
column_name == "created_at" ||
column_name == "updated_at"
getter_type = as_nilable_type(getter_type)
end

[getter_type, setter_type]
end

sig { params(column_type: T.untyped).returns(String) }
Expand Down Expand Up @@ -104,12 +137,6 @@ def type_for_activerecord_value(column_type)
end
end

sig { params(constant: Module).returns(T::Boolean) }
def do_not_generate_strong_types?(constant)
Object.const_defined?(:StrongTypeGeneration) &&
!(constant.singleton_class < Object.const_get(:StrongTypeGeneration))
end

sig { params(column_type: ActiveRecord::Enum::EnumType).returns(String) }
def enum_setter_type(column_type)
# In Rails < 7 this method is private. When support for that is dropped we can call the method directly
Expand Down
12 changes: 10 additions & 2 deletions lib/tapioca/dsl/pipeline.rb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Pipeline
error_handler: T.proc.params(error: String).void,
skipped_constants: T::Array[Module],
number_of_workers: T.nilable(Integer),
compiler_options: T::Hash[String, T::Hash[String, T.untyped]],
).void
end
def initialize(
Expand All @@ -42,7 +43,8 @@ def initialize(
excluded_compilers: [],
error_handler: $stderr.method(:puts).to_proc,
skipped_constants: [],
number_of_workers: nil
number_of_workers: nil,
compiler_options: {}
)
@active_compilers = T.let(
gather_active_compilers(requested_compilers, excluded_compilers),
Expand All @@ -53,6 +55,7 @@ def initialize(
@error_handler = error_handler
@skipped_constants = skipped_constants
@number_of_workers = number_of_workers
@compiler_options = compiler_options
@errors = T.let([], T::Array[String])
end

Expand Down Expand Up @@ -197,7 +200,12 @@ def rbi_for_constant(constant)
active_compilers.each do |compiler_class|
next unless compiler_class.handles?(constant)

compiler = compiler_class.new(self, file.root, constant)
compiler_key = T.must(compiler_class.name).dup
Tapioca::Dsl::Compilers::NAMESPACES.each do |namespace|
compiler_key.delete_prefix!(namespace)
end
options = @compiler_options.fetch(compiler_key, {})
compiler = compiler_class.new(self, file.root, constant, options)
compiler.decorate
rescue
$stderr.puts("Error: `#{compiler_class.name}` failed to generate RBI for `#{constant}`")
Expand Down
4 changes: 3 additions & 1 deletion lib/tapioca/helpers/config_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def validate_config_options(command_options, config_key, config_options)
when :hash
error_msg = "invalid value for option `#{config_option_key}` for key `#{config_key}` - expected " \
"`Hash[String, String]` but found `#{config_option_value}`"
all_strings = (config_option_value.keys + config_option_value.values).all? { |v| v.is_a?(String) }
values_to_validate = config_option_value.keys
values_to_validate += config_option_value.values unless config_option_key == "compiler_options"
all_strings = values_to_validate.all? { |v| v.is_a?(String) }
next build_error(error_msg) unless all_strings
end
end
Expand Down
22 changes: 16 additions & 6 deletions lib/tapioca/helpers/test/dsl_compiler.rb
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,14 @@ def activate_other_dsl_compilers(*compiler_classes)
context.activate_other_dsl_compilers(compiler_classes)
end

sig { params(constant_name: T.any(Symbol, String)).returns(String) }
def rbi_for(constant_name)
context.rbi_for(constant_name)
sig do
params(
constant_name: T.any(Symbol, String),
compiler_options: T::Hash[Symbol, T.untyped],
).returns(String)
end
def rbi_for(constant_name, compiler_options: {})
context.rbi_for(constant_name, compiler_options: compiler_options)
end

sig { returns(T::Array[String]) }
Expand Down Expand Up @@ -85,8 +90,13 @@ def gathered_constants
compiler_class.processable_constants.filter_map(&:name).sort
end

sig { params(constant_name: T.any(Symbol, String)).returns(String) }
def rbi_for(constant_name)
sig do
params(
constant_name: T.any(Symbol, String),
compiler_options: T::Hash[Symbol, T.untyped],
).returns(String)
end
def rbi_for(constant_name, compiler_options: {})
# Make sure this is a constant that we can handle.
unless gathered_constants.include?(constant_name.to_s)
raise "`#{constant_name}` is not processable by the `#{compiler_class}` compiler."
Expand All @@ -95,7 +105,7 @@ def rbi_for(constant_name)
file = RBI::File.new(strictness: "strong")
constant = Object.const_get(constant_name)

compiler = compiler_class.new(pipeline, file.root, constant)
compiler = compiler_class.new(pipeline, file.root, constant, compiler_options.transform_keys(&:to_s))
compiler.decorate

rbi = Tapioca::DEFAULT_RBI_FORMATTER.print_file(file)
Expand Down
6 changes: 6 additions & 0 deletions spec/spec_with_project.rb
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def assert_project_file_equal(path, expected)
assert_equal(expected, @project.read(path))
end

# Assert that the contents of `path` inside `@project` includes `expected`
sig { params(path: String, expected: String).void }
def assert_project_file_includes(path, expected)
assert_includes(@project.read(path), expected)
end

# Assert that `path` exists inside `@project`
sig { params(path: String).void }
def assert_project_file_exist(path)
Expand Down
Loading
Loading