Skip to content

Commit

Permalink
Merge pull request #1888 from Shopify/uk-at-compiler-options
Browse files Browse the repository at this point in the history
Start passing in an option to `ActiveRecordColumns` compiler for how to generate column types
  • Loading branch information
paracycle committed Jun 4, 2024
2 parents 95e60bc + ca2eb8d commit dcfb7da
Show file tree
Hide file tree
Showing 13 changed files with 340 additions and 375 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,7 @@ dsl:
app_root: "."
halt_upon_load_error: true
skip_constant: []
compiler_options: {}
gem:
outdir: sorbet/rbi/gems
file_header: true
Expand Down
29 changes: 29 additions & 0 deletions lib/tapioca/cli.rb
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,20 @@ def todo
banner: "constant [constant ...]",
desc: "Do not generate RBI definitions for the given application constant(s)",
default: []
option :compiler_options,
type: :hash,
desc: "Options to pass to the DSL compilers",
hide: true,
default: {}
def dsl(*constant_or_paths)
set_environment(options)

# Assume anything starting with a capital letter or colon is a class, otherwise a path
constants, paths = constant_or_paths.partition { |c| c =~ /\A[A-Z:]/ }

# Make sure compiler options are received as a hash
compiler_options = process_compiler_options

command_args = {
requested_constants: constants,
requested_paths: paths.map { |p| Pathname.new(p) },
Expand All @@ -161,6 +169,7 @@ def dsl(*constant_or_paths)
rbi_formatter: rbi_formatter(options),
app_root: options[:app_root],
halt_upon_load_error: options[:halt_upon_load_error],
compiler_options: compiler_options,
}

command = if options[:verify]
Expand Down Expand Up @@ -372,6 +381,26 @@ def exit_on_failure?

private

def process_compiler_options
compiler_options = options[:compiler_options]

# Parse all compiler option hash values as YAML if they are Strings
compiler_options.transform_values! do |value|
value = YAML.safe_load(value) if String === value
value
rescue YAML::Exception
raise MalformattedArgumentError,
"Option '--compiler-options' should have well-formatted YAML strings, but received: '#{value}'"
end

unless compiler_options.values.all? { |v| Hash === v }
raise MalformattedArgumentError,
"Option '--compiler-options' should be a hash of hashes, but received: '#{compiler_options}'"
end

compiler_options
end

def print_init_next_steps
say(<<~OUTPUT)
#{set_color("This project is now set up for use with Sorbet and Tapioca", :bold)}
Expand Down
6 changes: 5 additions & 1 deletion lib/tapioca/commands/abstract_dsl.rb
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class AbstractDsl < CommandWithoutTracker
rbi_formatter: RBIFormatter,
app_root: String,
halt_upon_load_error: T::Boolean,
compiler_options: T::Hash[String, T.untyped],
).void
end
def initialize(
Expand All @@ -45,7 +46,8 @@ def initialize(
gem_dir: DEFAULT_GEM_DIR,
rbi_formatter: DEFAULT_RBI_FORMATTER,
app_root: ".",
halt_upon_load_error: true
halt_upon_load_error: true,
compiler_options: {}
)
@requested_constants = requested_constants
@requested_paths = requested_paths
Expand All @@ -63,6 +65,7 @@ def initialize(
@app_root = app_root
@halt_upon_load_error = halt_upon_load_error
@skip_constant = skip_constant
@compiler_options = compiler_options

super()
end
Expand Down Expand Up @@ -129,6 +132,7 @@ def create_pipeline
},
skipped_constants: constantize(@skip_constant, ignore_missing: true),
number_of_workers: @number_of_workers,
compiler_options: @compiler_options,
)
end

Expand Down
15 changes: 13 additions & 2 deletions lib/tapioca/dsl/compiler.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class Compiler
sig { returns(RBI::Tree) }
attr_reader :root

sig { returns(T::Hash[String, T.untyped]) }
attr_reader :options

class << self
extend T::Sig

Expand Down Expand Up @@ -60,11 +63,19 @@ def all_modules
end
end

sig { params(pipeline: Tapioca::Dsl::Pipeline, root: RBI::Tree, constant: ConstantType).void }
def initialize(pipeline, root, constant)
sig do
params(
pipeline: Tapioca::Dsl::Pipeline,
root: RBI::Tree,
constant: ConstantType,
options: T::Hash[String, T.untyped],
).void
end
def initialize(pipeline, root, constant, options = {})
@pipeline = pipeline
@root = root
@constant = constant
@options = options
@errors = T.let([], T::Array[String])
end

Expand Down
15 changes: 14 additions & 1 deletion lib/tapioca/dsl/compilers/active_record_columns.rb
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,19 @@ def gather_constants

private

sig { returns(Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption) }
def column_type_option
@column_type_option ||= T.let(
Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption.from_serialized(
options.fetch(
"types",
Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption::Persisted.serialize,
),
),
T.nilable(Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption),
)
end

sig do
params(
klass: RBI::Scope,
Expand Down Expand Up @@ -174,7 +187,7 @@ def add_method(klass, name, methods_to_add, return_type: "void", parameters: [])
end
def add_methods_for_attribute(klass, attribute_name, column_name = attribute_name, methods_to_add = nil)
getter_type, setter_type = Helpers::ActiveRecordColumnTypeHelper
.new(constant)
.new(constant, column_type_option: column_type_option)
.type_for(attribute_name, column_name)

# Added by ActiveRecord::AttributeMethods::Read
Expand Down
5 changes: 4 additions & 1 deletion lib/tapioca/dsl/compilers/identity_cache.rb
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ def create_index_fetch_by_methods(field, klass)
).void
end
def create_aliased_fetch_by_methods(field, klass)
type, _ = Helpers::ActiveRecordColumnTypeHelper.new(constant).type_for(field.alias_name.to_s)
type, _ = Helpers::ActiveRecordColumnTypeHelper.new(
constant,
column_type_option: Helpers::ActiveRecordColumnTypeHelper::ColumnTypeOption::Nilable,
).type_for(field.alias_name.to_s)
multi_type = type.delete_prefix("T.nilable(").delete_suffix(")").delete_prefix("::")
suffix = field.send(:fetch_method_suffix)

Expand Down
69 changes: 48 additions & 21 deletions lib/tapioca/dsl/helpers/active_record_column_type_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,47 @@ class ActiveRecordColumnTypeHelper
extend T::Sig
include RBIHelper

sig { params(constant: T.class_of(ActiveRecord::Base)).void }
def initialize(constant)
class ColumnTypeOption < T::Enum
extend T::Sig
enums do
Untyped = new("untyped")
Nilable = new("nilable")
Persisted = new("persisted")
end

sig { returns(T::Boolean) }
def persisted?
self == ColumnTypeOption::Persisted
end

sig { returns(T::Boolean) }
def nilable?
self == ColumnTypeOption::Nilable
end

sig { returns(T::Boolean) }
def untyped?
self == ColumnTypeOption::Untyped
end
end

sig do
params(
constant: T.class_of(ActiveRecord::Base),
column_type_option: ColumnTypeOption,
).void
end
def initialize(constant, column_type_option: ColumnTypeOption::Persisted)
@constant = constant
@column_type_option = column_type_option
end

sig { params(attribute_name: String, column_name: String).returns([String, String]) }
sig do
params(
attribute_name: String,
column_name: String,
).returns([String, String])
end
def type_for(attribute_name, column_name = attribute_name)
return id_type if attribute_name == "id"

Expand All @@ -27,15 +62,19 @@ def type_for(attribute_name, column_name = attribute_name)
sig { returns([String, String]) }
def id_type
if @constant.respond_to?(:composite_primary_key?) && T.unsafe(@constant).composite_primary_key?
@constant.primary_key.map(&method(:column_type_for)).map { |tuple| "[#{tuple.join(", ")}]" }
@constant.primary_key.map do |column|
column_type_for(column)
end.map do |tuple|
"[#{tuple.join(", ")}]"
end
else
column_type_for(@constant.primary_key)
end
end

sig { params(column_name: String).returns([String, String]) }
def column_type_for(column_name)
return ["T.untyped", "T.untyped"] if do_not_generate_strong_types?(@constant)
return ["T.untyped", "T.untyped"] if @column_type_option.untyped?

column = @constant.columns_hash[column_name]
column_type = @constant.attribute_types[column_name]
Expand All @@ -48,18 +87,12 @@ def column_type_for(column_name)
getter_type
end

if column&.null
if @column_type_option.persisted? && !column&.null
[getter_type, setter_type]
else
getter_type = as_nilable_type(getter_type) unless not_nilable_serialized_column?(column_type)
return [getter_type, as_nilable_type(setter_type)]
[getter_type, as_nilable_type(setter_type)]
end

if Array(@constant.primary_key).include?(column_name) ||
column_name == "created_at" ||
column_name == "updated_at"
getter_type = as_nilable_type(getter_type)
end

[getter_type, setter_type]
end

sig { params(column_type: T.untyped).returns(String) }
Expand Down Expand Up @@ -115,12 +148,6 @@ def type_for_activerecord_value(column_type)
end
end

sig { params(constant: Module).returns(T::Boolean) }
def do_not_generate_strong_types?(constant)
Object.const_defined?(:StrongTypeGeneration) &&
!(constant.singleton_class < Object.const_get(:StrongTypeGeneration))
end

sig { params(column_type: ActiveRecord::Enum::EnumType).returns(String) }
def enum_setter_type(column_type)
# In Rails < 7 this method is private. When support for that is dropped we can call the method directly
Expand Down
12 changes: 10 additions & 2 deletions lib/tapioca/dsl/pipeline.rb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Pipeline
error_handler: T.proc.params(error: String).void,
skipped_constants: T::Array[Module],
number_of_workers: T.nilable(Integer),
compiler_options: T::Hash[String, T::Hash[String, T.untyped]],
).void
end
def initialize(
Expand All @@ -42,7 +43,8 @@ def initialize(
excluded_compilers: [],
error_handler: $stderr.method(:puts).to_proc,
skipped_constants: [],
number_of_workers: nil
number_of_workers: nil,
compiler_options: {}
)
@active_compilers = T.let(
gather_active_compilers(requested_compilers, excluded_compilers),
Expand All @@ -53,6 +55,7 @@ def initialize(
@error_handler = error_handler
@skipped_constants = skipped_constants
@number_of_workers = number_of_workers
@compiler_options = compiler_options
@errors = T.let([], T::Array[String])
end

Expand Down Expand Up @@ -197,7 +200,12 @@ def rbi_for_constant(constant)
active_compilers.each do |compiler_class|
next unless compiler_class.handles?(constant)

compiler = compiler_class.new(self, file.root, constant)
compiler_key = T.must(compiler_class.name).dup
Tapioca::Dsl::Compilers::NAMESPACES.each do |namespace|
compiler_key.delete_prefix!(namespace)
end
options = @compiler_options.fetch(compiler_key, {})
compiler = compiler_class.new(self, file.root, constant, options)
compiler.decorate
rescue
$stderr.puts("Error: `#{compiler_class.name}` failed to generate RBI for `#{constant}`")
Expand Down
4 changes: 3 additions & 1 deletion lib/tapioca/helpers/config_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def validate_config_options(command_options, config_key, config_options)
when :hash
error_msg = "invalid value for option `#{config_option_key}` for key `#{config_key}` - expected " \
"`Hash[String, String]` but found `#{config_option_value}`"
all_strings = (config_option_value.keys + config_option_value.values).all? { |v| v.is_a?(String) }
values_to_validate = config_option_value.keys
values_to_validate += config_option_value.values unless config_option_key == "compiler_options"
all_strings = values_to_validate.all? { |v| v.is_a?(String) }
next build_error(error_msg) unless all_strings
end
end
Expand Down
22 changes: 16 additions & 6 deletions lib/tapioca/helpers/test/dsl_compiler.rb
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,14 @@ def activate_other_dsl_compilers(*compiler_classes)
context.activate_other_dsl_compilers(compiler_classes)
end

sig { params(constant_name: T.any(Symbol, String)).returns(String) }
def rbi_for(constant_name)
context.rbi_for(constant_name)
sig do
params(
constant_name: T.any(Symbol, String),
compiler_options: T::Hash[Symbol, T.untyped],
).returns(String)
end
def rbi_for(constant_name, compiler_options: {})
context.rbi_for(constant_name, compiler_options: compiler_options)
end

sig { returns(T::Array[String]) }
Expand Down Expand Up @@ -85,8 +90,13 @@ def gathered_constants
compiler_class.processable_constants.filter_map(&:name).sort
end

sig { params(constant_name: T.any(Symbol, String)).returns(String) }
def rbi_for(constant_name)
sig do
params(
constant_name: T.any(Symbol, String),
compiler_options: T::Hash[Symbol, T.untyped],
).returns(String)
end
def rbi_for(constant_name, compiler_options: {})
# Make sure this is a constant that we can handle.
unless gathered_constants.include?(constant_name.to_s)
raise "`#{constant_name}` is not processable by the `#{compiler_class}` compiler."
Expand All @@ -95,7 +105,7 @@ def rbi_for(constant_name)
file = RBI::File.new(strictness: "strong")
constant = Object.const_get(constant_name)

compiler = compiler_class.new(pipeline, file.root, constant)
compiler = compiler_class.new(pipeline, file.root, constant, compiler_options.transform_keys(&:to_s))
compiler.decorate

rbi = Tapioca::DEFAULT_RBI_FORMATTER.print_file(file)
Expand Down
6 changes: 6 additions & 0 deletions spec/spec_with_project.rb
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def assert_project_file_equal(path, expected)
assert_equal(expected, @project.read(path))
end

# Assert that the contents of `path` inside `@project` includes `expected`
sig { params(path: String, expected: String).void }
def assert_project_file_includes(path, expected)
assert_includes(@project.read(path), expected)
end

# Assert that `path` exists inside `@project`
sig { params(path: String).void }
def assert_project_file_exist(path)
Expand Down
Loading

0 comments on commit dcfb7da

Please sign in to comment.