Skip to content

Commit

Permalink
Extend protobuf support for map and repeated
Browse files Browse the repository at this point in the history
Fields in protobuf can be marked as "repeated" or "map", which correspond to arrays and hashes, respectively. This commit extends our support for protobuf by correctly marking the types of these fields. It also adds an initialize method to each of the protobuf classes so that creating them can be typed.
  • Loading branch information
kddnewton committed Apr 5, 2021
1 parent 3e83b52 commit 4c654e0
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 17 deletions.
148 changes: 132 additions & 16 deletions lib/tapioca/compilers/dsl/protobuf.rb
Original file line number Diff line number Diff line change
Expand Up @@ -62,33 +62,99 @@ module Dsl
# end
# ~~~
class Protobuf < Base
# Parlour doesn't support type members out of the box, so adding the
# ability to do that here. This should be upstreamed.
class TypeMember < Parlour::RbiGenerator::RbiObject
extend T::Sig

sig { params(other: Object).returns(T::Boolean) }
def ==(other)
TypeMember === other && name == other.name
end

sig do
override
.params(indent_level: Integer, options: Parlour::RbiGenerator::Options)
.returns(T::Array[String])
end
def generate_rbi(indent_level, options)
[options.indented(indent_level, "#{name} = type_member")]
end

sig do
override
.params(others: T::Array[Parlour::RbiGenerator::RbiObject])
.returns(T::Boolean)
end
def mergeable?(others)
others.all? { |other| self == other }
end

sig { override.params(others: T::Array[Parlour::RbiGenerator::RbiObject]).void }
def merge_into_self(others); end

sig { override.returns(String) }
def describe
"Type Member (#{name})"
end
end

class Field < T::Struct
prop :name, String
prop :type, String
prop :init_type, String
prop :default, String

extend T::Sig

sig { returns(Parlour::RbiGenerator::Parameter) }
def to_init
Parlour::RbiGenerator::Parameter.new("#{name}:", type: init_type, default: default)
end
end

extend T::Sig

sig do
override.params(
root: Parlour::RbiGenerator::Namespace,
constant: T.class_of(Google::Protobuf::MessageExts)
constant: Module
).void
end
def decorate(root, constant)
descriptor = T.let(T.unsafe(constant).descriptor, Google::Protobuf::Descriptor)
return unless descriptor.any?

root.path(constant) do |klass|
descriptor.each do |desc|
create_descriptor_method(klass, desc)
if constant == Google::Protobuf::RepeatedField
create_type_members(klass, "Elem")
elsif constant == Google::Protobuf::Map
create_type_members(klass, "Key", "Value")
else
descriptor = T.let(T.unsafe(constant).descriptor, Google::Protobuf::Descriptor)
fields = descriptor.map { |desc| create_descriptor_method(klass, desc) }
fields.sort_by!(&:name)

create_method(klass, "initialize", parameters: fields.map!(&:to_init))
end
end
end

sig { override.returns(T::Enumerable[Module]) }
def gather_constants
classes = T.cast(ObjectSpace.each_object(Class), T::Enumerable[Class])
classes.select { |c| c < Google::Protobuf::MessageExts && !c.singleton_class? }
marker = Google::Protobuf::MessageExts::ClassMethods
results = T.cast(ObjectSpace.each_object(marker).to_a, T::Array[Module])
results.any? ? results + [Google::Protobuf::RepeatedField, Google::Protobuf::Map] : []
end

private

sig { params(klass: Parlour::RbiGenerator::Namespace, names: String).void }
def create_type_members(klass, *names)
klass.create_extend("T::Generic")

names.each do |name|
klass.children << TypeMember.new(klass.generator, name)
end
end

sig do
params(
descriptor: Google::Protobuf::FieldDescriptor
Expand All @@ -113,30 +179,80 @@ def type_of(descriptor)
end
end

sig { params(descriptor: Google::Protobuf::FieldDescriptor).returns(Field) }
def field_of(descriptor)
if descriptor.label == :repeated
# Here we're going to check if the submsg_name is named according to
# how Google names map entries.
# https://github.com/protocolbuffers/protobuf/blob/f82e26/ruby/ext/google/protobuf_c/defs.c#L1963-L1966
if descriptor.submsg_name.to_s.end_with?("_MapEntry_#{descriptor.name}")
key = descriptor.subtype.lookup('key')
value = descriptor.subtype.lookup('value')

key_type = type_of(key)
value_type = type_of(value)
type = "Google::Protobuf::Map[#{key_type}, #{value_type}]"

default_args = [key.type.inspect, value.type.inspect]
default_args << value_type if %i[enum message].include?(value.type)

Field.new(
name: descriptor.name,
type: type,
init_type: "T.any(#{type}, T::Hash[#{key_type}, #{value_type}])",
default: "Google::Protobuf::Map.new(#{default_args.join(', ')})"
)
else
elem_type = type_of(descriptor)
type = "Google::Protobuf::RepeatedField[#{elem_type}]"

default_args = [descriptor.type.inspect]
default_args << elem_type if %i[enum message].include?(descriptor.type)

Field.new(
name: descriptor.name,
type: type,
init_type: "T.any(#{type}, T::Array[#{elem_type}])",
default: "Google::Protobuf::RepeatedField.new(#{default_args.join(', ')})"
)
end
else
type = type_of(descriptor)

Field.new(
name: descriptor.name,
type: type,
init_type: type,
default: "nil"
)
end
end

sig do
params(
klass: Parlour::RbiGenerator::Namespace,
desc: Google::Protobuf::FieldDescriptor,
).void
).returns(Field)
end
def create_descriptor_method(klass, desc)
name = desc.name
type = type_of(desc)
field = field_of(desc)

create_method(
klass,
name,
return_type: type
field.name,
return_type: field.type
)

create_method(
klass,
"#{name}=",
"#{field.name}=",
parameters: [
Parlour::RbiGenerator::Parameter.new("value", type: type),
Parlour::RbiGenerator::Parameter.new("value", type: field.type),
],
return_type: type
return_type: field.type
)

field
end
end
end
Expand Down
98 changes: 97 additions & 1 deletion spec/tapioca/compilers/dsl/protobuf_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ class Tapioca::Compilers::Dsl::ProtobufSpec < DslSpec
Cart = Google::Protobuf::DescriptorPool.generated_pool.lookup("MyCart").msgclass
RUBY

assert_equal(["Cart"], gathered_constants)
assert_equal(
["Cart", "Google::Protobuf::Map", "Google::Protobuf::RepeatedField"],
gathered_constants
)
end
end

Expand Down Expand Up @@ -56,6 +59,9 @@ def customer_id; end
sig { params(value: Integer).returns(Integer) }
def customer_id=(value); end
sig { params(customer_id: Integer, shop_id: Integer).void }
def initialize(customer_id: nil, shop_id: nil); end
sig { returns(Integer) }
def shop_id; end
Expand Down Expand Up @@ -88,6 +94,9 @@ def events; end
sig { params(value: String).returns(String) }
def events=(value); end
sig { params(events: String).void }
def initialize(events: nil); end
end
RBI

Expand Down Expand Up @@ -118,6 +127,9 @@ def cart_item_index; end
sig { params(value: Google::Protobuf::UInt64Value).returns(Google::Protobuf::UInt64Value) }
def cart_item_index=(value); end
sig { params(cart_item_index: Google::Protobuf::UInt64Value).void }
def initialize(cart_item_index: nil); end
end
RBI

Expand Down Expand Up @@ -149,6 +161,9 @@ def cart_item_index=(value); end
expected = <<~RBI
# typed: strong
class Cart
sig { params(value_type: Cart::VALUE_TYPE).void }
def initialize(value_type: nil); end
sig { returns(Cart::VALUE_TYPE) }
def value_type; end
Expand Down Expand Up @@ -185,6 +200,9 @@ def value_type=(value); end
expected = <<~RBI
# typed: strong
class Cart
sig { params(value_type: Cart::MYVALUETYPE).void }
def initialize(value_type: nil); end
sig { returns(Cart::MYVALUETYPE) }
def value_type; end
Expand All @@ -196,6 +214,84 @@ def value_type=(value); end
assert_equal(expected, rbi_for(:Cart))
end

it("generates methods in RBI files for repeated fields in Protobufs") do
add_ruby_file("protobuf.rb", <<~RUBY)
require 'google/protobuf/wrappers_pb'
Google::Protobuf::DescriptorPool.generated_pool.build do
add_file("cart.proto", :syntax => :proto3) do
add_message "MyCart" do
repeated :customer_ids, :int32, 1
repeated :indices, :message, 2, "google.protobuf.UInt64Value"
end
end
end
Cart = Google::Protobuf::DescriptorPool.generated_pool.lookup("MyCart").msgclass
RUBY

expected = <<~RBI
# typed: strong
class Cart
sig { returns(Google::Protobuf::RepeatedField[Integer]) }
def customer_ids; end
sig { params(value: Google::Protobuf::RepeatedField[Integer]).returns(Google::Protobuf::RepeatedField[Integer]) }
def customer_ids=(value); end
sig { returns(Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value]) }
def indices; end
sig { params(value: Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value]).returns(Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value]) }
def indices=(value); end
sig { params(customer_ids: T.any(Google::Protobuf::RepeatedField[Integer], T::Array[Integer]), indices: T.any(Google::Protobuf::RepeatedField[Google::Protobuf::UInt64Value], T::Array[Google::Protobuf::UInt64Value])).void }
def initialize(customer_ids: Google::Protobuf::RepeatedField.new(:int32), indices: Google::Protobuf::RepeatedField.new(:message, Google::Protobuf::UInt64Value)); end
end
RBI

assert_equal(expected, rbi_for(:Cart))
end

it("generates methods in RBI files for map fields in Protobufs") do
add_ruby_file("protobuf.rb", <<~RUBY)
require 'google/protobuf/wrappers_pb'
Google::Protobuf::DescriptorPool.generated_pool.build do
add_file("cart.proto", :syntax => :proto3) do
add_message "MyCart" do
map :customers, :string, :int32, 1
map :stores, :string, :message, 2, "google.protobuf.UInt64Value"
end
end
end
Cart = Google::Protobuf::DescriptorPool.generated_pool.lookup("MyCart").msgclass
RUBY

expected = <<~RBI
# typed: strong
class Cart
sig { returns(Google::Protobuf::Map[String, Integer]) }
def customers; end
sig { params(value: Google::Protobuf::Map[String, Integer]).returns(Google::Protobuf::Map[String, Integer]) }
def customers=(value); end
sig { params(customers: T.any(Google::Protobuf::Map[String, Integer], T::Hash[String, Integer]), stores: T.any(Google::Protobuf::Map[String, Google::Protobuf::UInt64Value], T::Hash[String, Google::Protobuf::UInt64Value])).void }
def initialize(customers: Google::Protobuf::Map.new(:string, :int32), stores: Google::Protobuf::Map.new(:string, :message, Google::Protobuf::UInt64Value)); end
sig { returns(Google::Protobuf::Map[String, Google::Protobuf::UInt64Value]) }
def stores; end
sig { params(value: Google::Protobuf::Map[String, Google::Protobuf::UInt64Value]).returns(Google::Protobuf::Map[String, Google::Protobuf::UInt64Value]) }
def stores=(value); end
end
RBI

assert_equal(expected, rbi_for(:Cart))
end

it("generates methods in RBI files for classes with Protobuf with all types") do
add_ruby_file("protobuf.rb", <<~RUBY)
require 'google/protobuf/timestamp_pb'
Expand Down

0 comments on commit 4c654e0

Please sign in to comment.