diff --git a/lib/tapioca/dsl/compilers/protobuf.rb b/lib/tapioca/dsl/compilers/protobuf.rb index 79f2ce83d..3a33ac40e 100644 --- a/lib/tapioca/dsl/compilers/protobuf.rb +++ b/lib/tapioca/dsl/compilers/protobuf.rb @@ -90,23 +90,51 @@ def decorate elsif constant == Google::Protobuf::Map create_type_members(klass, "Key", "Value") else - descriptor = T.let(T.unsafe(constant).descriptor, Google::Protobuf::Descriptor) - descriptor.each_oneof { |oneof| create_oneof_method(klass, oneof) } - fields = descriptor.map { |desc| create_descriptor_method(klass, desc) } - fields.sort_by!(&:name) + descriptor = T.unsafe(constant).descriptor - parameters = fields.map do |field| - create_kw_opt_param(field.name, type: field.init_type, default: field.default) - end + case descriptor + when Google::Protobuf::EnumDescriptor + descriptor.to_h.each do |sym, val| + klass.create_constant(sym.to_s, value: val.to_s) + end + + klass.create_method( + "lookup", + parameters: [create_param("number", type: "Integer")], + return_type: "T.nilable(Symbol)", + class_method: true, + ) + klass.create_method( + "resolve", + parameters: [create_param("symbol", type: "Symbol")], + return_type: "T.nilable(Integer)", + class_method: true, + ) + klass.create_method( + "descriptor", + return_type: "Google::Protobuf::EnumDescriptor", + class_method: true, + ) + when Google::Protobuf::Descriptor + descriptor.each_oneof { |oneof| create_oneof_method(klass, oneof) } + fields = descriptor.map { |desc| create_descriptor_method(klass, desc) } + fields.sort_by!(&:name) - if fields.all? { |field| FIELD_RE.match?(field.name) } - klass.create_method("initialize", parameters: parameters, return_type: "void") + parameters = fields.map do |field| + create_kw_opt_param(field.name, type: field.init_type, default: field.default) + end + + if fields.all? { |field| FIELD_RE.match?(field.name) } + klass.create_method("initialize", parameters: parameters, return_type: "void") + else + # One of the fields has an incorrect name for a named parameter so creating the default initialize for + # it would create a RBI with a syntax error. + # The workaround is to create an initialize that takes a **kwargs instead. + kwargs_parameter = create_kw_rest_param("fields", type: "T.untyped") + klass.create_method("initialize", parameters: [kwargs_parameter], return_type: "void") + end else - # One of the fields has an incorrect name for a named parameter so creating the default initialize for - # it would create a RBI with a syntax error. - # The workaround is to create an initialize that takes a **kwargs instead. - kwargs_parameter = create_kw_rest_param("fields", type: "T.untyped") - klass.create_method("initialize", parameters: [kwargs_parameter], return_type: "void") + raise TypeError, "Unexpected descriptor class: #{descriptor.class.name}" end end end @@ -118,7 +146,12 @@ class << self sig { override.returns(T::Enumerable[Module]) } def gather_constants marker = Google::Protobuf::MessageExts::ClassMethods - results = T.cast(ObjectSpace.each_object(marker).to_a, T::Array[Module]) + + enum_modules = ObjectSpace.each_object(Google::Protobuf::EnumDescriptor).map do |desc| + T.cast(desc, Google::Protobuf::EnumDescriptor).enummodule + end + + results = T.cast(ObjectSpace.each_object(marker).to_a, T::Array[Module]).concat(enum_modules) results.any? ? results + [Google::Protobuf::RepeatedField, Google::Protobuf::Map] : [] end end @@ -142,7 +175,13 @@ def create_type_members(klass, *names) def type_of(descriptor) case descriptor.type when :enum - descriptor.subtype.enummodule.name + # According to https://developers.google.com/protocol-buffers/docs/reference/ruby-generated#enum + # > You may assign either a number or a symbol to an enum field. + # > When reading the value back, it will be a symbol if the enum + # > value is known, or a number if it is unknown. Since proto3 uses + # > open enum semantics, any number may be assigned to an enum + # > field, even if it was not defined in the enum. + "T.any(Symbol, Integer)" when :message descriptor.subtype.msgclass.name when :int32, :int64, :uint32, :uint64 @@ -183,7 +222,7 @@ def field_of(descriptor) Field.new( name: descriptor.name, type: type, - init_type: "T.any(#{type}, T::Hash[#{key_type}, #{value_type}])", + init_type: "T.nilable(T.any(#{type}, T::Hash[#{key_type}, #{value_type}]))", default: "Google::Protobuf::Map.new(#{default_args.join(", ")})" ) else @@ -196,18 +235,19 @@ def field_of(descriptor) Field.new( name: descriptor.name, type: type, - init_type: "T.any(#{type}, T::Array[#{elem_type}])", + init_type: "T.nilable(T.any(#{type}, T::Array[#{elem_type}]))", default: "Google::Protobuf::RepeatedField.new(#{default_args.join(", ")})" ) end else type = type_of(descriptor) - type = as_nilable_type(type) if nilable_descriptor?(descriptor) + nilable_type = as_nilable_type(type) + type = nilable_type if nilable_descriptor?(descriptor) Field.new( name: descriptor.name, type: type, - init_type: type, + init_type: nilable_type, default: "nil" ) end @@ -230,7 +270,7 @@ def create_descriptor_method(klass, desc) klass.create_method( "#{field.name}=", parameters: [create_param("value", type: field.type)], - return_type: field.type + return_type: "void" ) field diff --git a/spec/tapioca/dsl/compilers/protobuf_spec.rb b/spec/tapioca/dsl/compilers/protobuf_spec.rb index c6b973ac2..174d541a8 100644 --- a/spec/tapioca/dsl/compilers/protobuf_spec.rb +++ b/spec/tapioca/dsl/compilers/protobuf_spec.rb @@ -57,19 +57,19 @@ class ProtobufSpec < ::DslSpec # typed: strong class Cart - sig { params(customer_id: Integer, shop_id: Integer).void } + sig { params(customer_id: T.nilable(Integer), shop_id: T.nilable(Integer)).void } def initialize(customer_id: nil, shop_id: nil); end sig { returns(Integer) } def customer_id; end - sig { params(value: Integer).returns(Integer) } + sig { params(value: Integer).void } def customer_id=(value); end sig { returns(Integer) } def shop_id; end - sig { params(value: Integer).returns(Integer) } + sig { params(value: Integer).void } def shop_id=(value); end end RBI @@ -94,13 +94,13 @@ def shop_id=(value); end # typed: strong class Cart - sig { params(events: String).void } + sig { params(events: T.nilable(String)).void } def initialize(events: nil); end sig { returns(String) } def events; end - sig { params(value: String).returns(String) } + sig { params(value: String).void } def events=(value); end end RBI @@ -134,7 +134,7 @@ def initialize(cart_item_index: nil); end sig { returns(T.nilable(Google::Protobuf::UInt64Value)) } def cart_item_index; end - sig { params(value: T.nilable(Google::Protobuf::UInt64Value)).returns(T.nilable(Google::Protobuf::UInt64Value)) } + sig { params(value: T.nilable(Google::Protobuf::UInt64Value)).void } def cart_item_index=(value); end end RBI @@ -168,17 +168,39 @@ def cart_item_index=(value); end # typed: strong class Cart - sig { params(value_type: Cart::VALUE_TYPE).void } + sig { params(value_type: T.nilable(T.any(Symbol, Integer))).void } def initialize(value_type: nil); end - sig { returns(Cart::VALUE_TYPE) } + sig { returns(T.any(Symbol, Integer)) } def value_type; end - sig { params(value: Cart::VALUE_TYPE).returns(Cart::VALUE_TYPE) } + sig { params(value: T.any(Symbol, Integer)).void } def value_type=(value); end end RBI + expected_enum_rbi = <<~RBI + # typed: strong + + module Cart::VALUE_TYPE + class << self + sig { returns(Google::Protobuf::EnumDescriptor) } + def descriptor; end + + sig { params(number: Integer).returns(T.nilable(Symbol)) } + def lookup(number); end + + sig { params(symbol: Symbol).returns(T.nilable(Integer)) } + def resolve(symbol); end + end + + FIXED_AMOUNT = 1 + NULL = 0 + PERCENTAGE = 2 + end + RBI + + assert_equal(expected_enum_rbi, rbi_for("Cart::VALUE_TYPE")) assert_equal(expected, rbi_for(:Cart)) end @@ -208,13 +230,13 @@ def value_type=(value); end # typed: strong class Cart - sig { params(value_type: Cart::MYVALUETYPE).void } + sig { params(value_type: T.nilable(T.any(Symbol, Integer))).void } def initialize(value_type: nil); end - sig { returns(Cart::MYVALUETYPE) } + sig { returns(T.any(Symbol, Integer)) } def value_type; end - sig { params(value: Cart::MYVALUETYPE).returns(Cart::MYVALUETYPE) } + sig { params(value: T.any(Symbol, Integer)).void } def value_type=(value); end end RBI @@ -242,19 +264,19 @@ def value_type=(value); end # typed: strong class Cart - sig { params(customer_ids: T.any(Google::Protobuf::RepeatedField[Integer], T::Array[Integer]), indices: T.any(Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value], T::Array[Google::Protobuf::UInt64Value])).void } + sig { params(customer_ids: T.nilable(T.any(Google::Protobuf::RepeatedField[Integer], T::Array[Integer])), indices: T.nilable(T.any(Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value], T::Array[Google::Protobuf::UInt64Value]))).void } def initialize(customer_ids: Google::Protobuf::RepeatedField.new(:int32), indices: Google::Protobuf::RepeatedField.new(:message, Google::Protobuf::UInt64Value)); end sig { returns(Google::Protobuf::RepeatedField[Integer]) } def customer_ids; end - sig { params(value: Google::Protobuf::RepeatedField[Integer]).returns(Google::Protobuf::RepeatedField[Integer]) } + sig { params(value: Google::Protobuf::RepeatedField[Integer]).void } def customer_ids=(value); end sig { returns(Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value]) } def indices; end - sig { params(value: Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value]).returns(Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value]) } + sig { params(value: Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value]).void } def indices=(value); end end RBI @@ -282,19 +304,19 @@ def indices=(value); end # typed: strong class Cart - sig { params(customers: T.any(Google::Protobuf::Map[String, Integer], T::Hash[String, Integer]), stores: T.any(Google::Protobuf::Map[String, Google::Protobuf::UInt64Value], T::Hash[String, Google::Protobuf::UInt64Value])).void } + sig { params(customers: T.nilable(T.any(Google::Protobuf::Map[String, Integer], T::Hash[String, Integer])), stores: T.nilable(T.any(Google::Protobuf::Map[String, Google::Protobuf::UInt64Value], T::Hash[String, Google::Protobuf::UInt64Value]))).void } def initialize(customers: Google::Protobuf::Map.new(:string, :int32), stores: Google::Protobuf::Map.new(:string, :message, Google::Protobuf::UInt64Value)); end sig { returns(Google::Protobuf::Map[String, Integer]) } def customers; end - sig { params(value: Google::Protobuf::Map[String, Integer]).returns(Google::Protobuf::Map[String, Integer]) } + sig { params(value: Google::Protobuf::Map[String, Integer]).void } def customers=(value); end sig { returns(Google::Protobuf::Map[String, Google::Protobuf::UInt64Value]) } def stores; end - sig { params(value: Google::Protobuf::Map[String, Google::Protobuf::UInt64Value]).returns(Google::Protobuf::Map[String, Google::Protobuf::UInt64Value]) } + sig { params(value: Google::Protobuf::Map[String, Google::Protobuf::UInt64Value]).void } def stores=(value); end end RBI @@ -329,47 +351,47 @@ def stores=(value); end rbi_output = rbi_for(:Cart) assert_includes(rbi_output, indented(<<~RBI, 2)) - sig { params(value: T::Boolean).returns(T::Boolean) } + sig { params(value: T::Boolean).void } def bool_value=(value); end RBI assert_includes(rbi_output, indented(<<~RBI, 2)) - sig { params(value: String).returns(String) } + sig { params(value: String).void } def byte_value=(value); end RBI assert_includes(rbi_output, indented(<<~RBI, 2)) - sig { params(value: Integer).returns(Integer) } + sig { params(value: Integer).void } def customer_id=(value); end RBI assert_includes(rbi_output, indented(<<~RBI, 2)) - sig { params(value: Integer).returns(Integer) } + sig { params(value: Integer).void } def id=(value); end RBI assert_includes(rbi_output, indented(<<~RBI, 2)) - sig { params(value: Integer).returns(Integer) } + sig { params(value: Integer).void } def item_id=(value); end RBI assert_includes(rbi_output, indented(<<~RBI, 2)) - sig { params(value: Float).returns(Float) } + sig { params(value: Float).void } def money_value=(value); end RBI assert_includes(rbi_output, indented(<<~RBI, 2)) - sig { params(value: Float).returns(Float) } + sig { params(value: Float).void } def number_value=(value); end RBI assert_includes(rbi_output, indented(<<~RBI, 2)) - sig { params(value: Integer).returns(Integer) } + sig { params(value: Integer).void } def shop_id=(value); end RBI assert_includes(rbi_output, indented(<<~RBI, 2)) - sig { params(value: String).returns(String) } + sig { params(value: String).void } def string_value=(value); end RBI end @@ -398,13 +420,13 @@ def initialize(**fields); end sig { returns(Integer) } def ShopID; end - sig { params(value: Integer).returns(Integer) } + sig { params(value: Integer).void } def ShopID=(value); end sig { returns(String) } def ShopName; end - sig { params(value: String).returns(String) } + sig { params(value: String).void } def ShopName=(value); end end RBI