Skip to content

Commit

Permalink
Rewrite as an extension
Browse files Browse the repository at this point in the history
The initial intention of using FFI was to make it compatible with
Rubinius and JRuby just in case we need to use these platforms.

However, Rubinius is no longer maintained, and TruffleRuby is becoming a
promising alternative to JRuby, and it supports extensions. It's time to
ditch FFI.

This fixes the incompatibility with the latest RubyGems which was
discussed at:
rubygems/rubygems#6218
  • Loading branch information
akihikodaki committed Feb 21, 2023
1 parent c183b6d commit 59f3ae9
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 252 deletions.
8 changes: 0 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,11 @@ FreeBSD port is available as `rubygem-cld3` in `textproc` category.

https://svnweb.freebsd.org/ports/head/textproc/rubygem-cld3/

#### JRuby
JRuby has a bug which prevents the feature detection. Apply the following
change:
https://github.com/jruby/jruby/pull/4118/commits/edad375ef4dcf195b19ce0afe4befac66468c736

### Troubleshooting
`gem install cld3` triggers native library building. If it fails, it is likely
that some required facilities are missing. Make sure C++ compiler is installed.
I recommend [GCC](https://gcc.gnu.org/) as a C++ compiler.

Runtime errors are likely to be issues of [FFI](https://github.com/ffi/ffi) or
programming errors. Make sure they are all correct.

If you cannot identify the cause of your problem, run spec of this library and
see whether the problem is reproducible with it or not. Spec is not included in
the gem, so clone the source code repository and then run `rake spec`.
Expand Down
5 changes: 2 additions & 3 deletions Rakefile
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ int_path = FileList[
"ext/cld3/cld_3/protos/sentence.pb.h",
"ext/cld3/cld_3/protos/task_spec.pb.h",
"ext/cld3/extconf.rb",
"ext/cld3/libcld3.def",
"ext/cld3/cld3_ext.def",
"ext/cld3/nnet_language_identifier_c.cc",
"lib/cld3/unstable.rb",
"lib/cld3.rb",
"sig/cld3.rbs",
"spec/cld3_spec.rb"
Expand All @@ -109,7 +108,7 @@ task :default => :package

desc "Run the tests"
task "spec" => "intermediate/ext/cld3/Makefile" do
sh "make -C intermediate/ext/cld3"
sh "make -C intermediate/ext/cld3 install sitearchdir=../../lib sitelibdir=../../lib"
sh "cd intermediate && bundle exec rspec"
end

Expand Down
1 change: 0 additions & 1 deletion cld3.gemspec
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ Gem::Specification.new do |gem|
gem.author = "Akihiko Odaki"
gem.email = "akihiko.odaki@gmail.com"
gem.required_ruby_version = [ ">= 2.7.0", "< 3.3.0" ]
gem.add_dependency "ffi", [ ">= 1.1.0", "< 1.16.0" ]
gem.add_development_dependency "rbs", [ ">= 2.6.0", "< 2.7.0" ]
gem.add_development_dependency "rspec", [ ">= 3.11.0", "< 3.12.0" ]
gem.add_development_dependency "steep", [ ">= 1.0.0", "< 1.1.0" ]
Expand Down
2 changes: 2 additions & 0 deletions ext/cld3/cld3_ext.def
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
EXPORTS
Init_cld3_ext
3 changes: 1 addition & 2 deletions ext/cld3/extconf.rb
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,4 @@ def ln_fallback(source, destination)
}

$CXXFLAGS += " -fvisibility=hidden -std=c++17"
$LIBRUBYARG = ""
create_makefile("libcld3")
create_makefile("cld3_ext")
8 changes: 0 additions & 8 deletions ext/cld3/libcld3.def

This file was deleted.

232 changes: 162 additions & 70 deletions ext/cld3/nnet_language_identifier_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <iostream>
#include <string>
#include <utility>
#include <ruby.h>
#include "nnet_language_identifier.h"

#if defined _WIN32 || defined __CYGWIN__
Expand All @@ -27,89 +28,180 @@ limitations under the License.
#endif

struct Result {
struct {
const char *data;
std::size_t size;
} language;
struct {
const chrome_lang_id::NNetLanguageIdentifier::SpanInfo *data;
std::size_t size;
} byte_ranges;
float probability;
float proportion;
bool is_reliable;
VALUE result_klass;
VALUE span_info_klass;
const chrome_lang_id::NNetLanguageIdentifier::Result& data;

VALUE convert() const {
if (data.language == chrome_lang_id::NNetLanguageIdentifier::kUnknown)
return Qnil;

VALUE byte_ranges = rb_ary_new2(data.byte_ranges.size());
for (auto& byte_range_data : data.byte_ranges) {
VALUE argv[] = {
INT2NUM(byte_range_data.start_index),
INT2NUM(byte_range_data.end_index),
DBL2NUM(byte_range_data.probability),
};

VALUE byte_range = rb_class_new_instance(sizeof(argv) / sizeof(*argv),
argv,
span_info_klass);
rb_ary_push(byte_ranges, byte_range);
}

VALUE argv[] = {
ID2SYM(rb_intern2(data.language.data(), data.language.size())),
DBL2NUM(data.probability),
data.is_reliable ? Qtrue : Qfalse,
DBL2NUM(data.proportion),
byte_ranges,
};

return rb_class_new_instance(sizeof(argv) / sizeof(*argv), argv,
result_klass);
}
};

struct OwningResult {
OwningResult(chrome_lang_id::NNetLanguageIdentifier::Result&& result) {
references.language = std::move(result.language);
references.byte_ranges = std::move(result.byte_ranges);
plain.language.data = references.language.data();
plain.language.size = references.language.size();
plain.byte_ranges.data = references.byte_ranges.data();
plain.byte_ranges.size = references.byte_ranges.size();
plain.probability = result.probability;
plain.proportion = result.proportion;
plain.is_reliable = result.is_reliable;
struct ResultVector {
VALUE result_klass;
VALUE span_info_klass;
VALUE buffer;
const std::vector<chrome_lang_id::NNetLanguageIdentifier::Result>& data;

VALUE convert() const {
for (auto& element_data : data) {
Result result { result_klass, span_info_klass, element_data };
VALUE element = result.convert();
if (element == Qnil)
break;

rb_ary_push(buffer, element);
}

return buffer;
}
};

Result plain;
struct {
std::string language;
std::vector<chrome_lang_id::NNetLanguageIdentifier::SpanInfo> byte_ranges;
} references;
template<typename T>
VALUE convert_protected(VALUE arg)
{
auto result = reinterpret_cast<const T *>(arg);
return result->convert();
}

static void dfree(void *arg) {
auto data = static_cast<chrome_lang_id::NNetLanguageIdentifier *>(arg);
data->~NNetLanguageIdentifier();
xfree(arg);
}

static size_t dsize(const void *data) {
return sizeof(chrome_lang_id::NNetLanguageIdentifier);
}

static const rb_data_type_t data_type = {
.wrap_struct_name = "CLD3::NNetLanguageIdentifier",
.function = {
.dfree = dfree,
.dsize = dsize,
},
.flags = RUBY_TYPED_FREE_IMMEDIATELY
};

extern "C" {
EXPORT OwningResult *NNetLanguageIdentifier_find_language(
chrome_lang_id::NNetLanguageIdentifier *instance,
const char *data,
std::size_t size) {
return new OwningResult(instance->FindLanguage(std::string(data, size)));
static VALUE find_language(VALUE obj,
VALUE result_klass, VALUE span_info_klass,
VALUE text) {
int state;
VALUE converted;

{
chrome_lang_id::NNetLanguageIdentifier *data;
TypedData_Get_Struct(obj, chrome_lang_id::NNetLanguageIdentifier,
&data_type, data);
std::string text_string = std::string(RSTRING_PTR(text), RSTRING_LEN(text));
auto result_data = data->FindLanguage(text_string);
Result result { result_klass, span_info_klass, result_data };

converted = rb_protect(convert_protected<Result>,
reinterpret_cast<VALUE>(&result),
&state);
}

EXPORT std::vector<chrome_lang_id::NNetLanguageIdentifier::Result>*
NNetLanguageIdentifier_find_top_n_most_freq_langs(
chrome_lang_id::NNetLanguageIdentifier *instance,
const char *data, std::size_t size, int num_langs) {
std::string text(data, size);
return new auto(instance->FindTopNMostFreqLangs(text, num_langs));
}
if (state)
rb_jump_tag(state);

EXPORT void delete_NNetLanguageIdentifier(
chrome_lang_id::NNetLanguageIdentifier *pointer) {
delete pointer;
}
return converted;
}

EXPORT void delete_result(OwningResult *pointer) {
delete pointer;
static VALUE find_top_n_most_freq_langs(VALUE obj,
VALUE result_klass,
VALUE span_info_klass,
VALUE text,
VALUE num_langs) {
int state;
VALUE converted;

{
chrome_lang_id::NNetLanguageIdentifier *data;
TypedData_Get_Struct(obj, chrome_lang_id::NNetLanguageIdentifier,
&data_type, data);
VALUE buffer = rb_ary_new2(NUM2INT(num_langs));
std::string text_string = std::string(RSTRING_PTR(text), RSTRING_LEN(text));
auto result_data = data->FindTopNMostFreqLangs(text_string, num_langs);
ResultVector result { result_klass, span_info_klass, buffer, result_data };

converted = rb_protect(convert_protected<ResultVector>,
reinterpret_cast<VALUE>(&result),
&state);
}

EXPORT void delete_results(
std::vector<chrome_lang_id::NNetLanguageIdentifier::Result> *pointer) {
delete pointer;
}
if (state)
rb_jump_tag(state);

EXPORT chrome_lang_id::NNetLanguageIdentifier *new_NNetLanguageIdentifier(
int min_num_bytes, int max_num_bytes) {
return new chrome_lang_id::NNetLanguageIdentifier(
min_num_bytes, max_num_bytes);
}
return converted;
}

static VALUE make(VALUE klass, VALUE min_num_bytes, VALUE max_num_bytes) {
int min_num_bytes_int = NUM2INT(min_num_bytes);
int max_num_bytes_int = NUM2INT(max_num_bytes);
chrome_lang_id::NNetLanguageIdentifier *data;
VALUE value = TypedData_Make_Struct(klass,
chrome_lang_id::NNetLanguageIdentifier,
&data_type, data);
new (data) chrome_lang_id::NNetLanguageIdentifier(min_num_bytes_int, max_num_bytes_int);
return value;
}

EXPORT Result refer_to_nth_result(
std::vector<chrome_lang_id::NNetLanguageIdentifier::Result> *results,
std::size_t index) {
Result c;
auto& cc = (*results)[index];

c.language.data = cc.language.data();
c.language.size = cc.language.size();
c.byte_ranges.data = cc.byte_ranges.data();
c.byte_ranges.size = cc.byte_ranges.size();
c.probability = cc.probability;
c.proportion = cc.proportion;
c.is_reliable = cc.is_reliable;

return c;
extern "C" EXPORT void Init_cld3_ext() {
VALUE cld3 = rb_const_get(rb_cObject, rb_intern("CLD3"));
VALUE identifier =
rb_const_get(cld3, rb_intern("NNetLanguageIdentifier"));
VALUE unstable = rb_const_get(identifier, rb_intern("Unstable"));
VALUE params = rb_const_get(cld3, rb_intern("TaskContextParams"));
VALUE language_names = rb_const_get(params, rb_intern("LANGUAGE_NAMES"));

rb_define_const(identifier, "MIN_NUM_BYTES_TO_CONSIDER",
INT2NUM(chrome_lang_id::NNetLanguageIdentifier::kMinNumBytesToConsider));
rb_define_const(identifier, "MAX_NUM_BYTES_TO_CONSIDER",
INT2NUM(chrome_lang_id::NNetLanguageIdentifier::kMaxNumBytesToConsider));
rb_define_const(identifier, "MAX_NUM_INPUT_BYTES_TO_CONSIDER",
INT2NUM(chrome_lang_id::NNetLanguageIdentifier::kMaxNumInputBytesToConsider));
rb_define_const(identifier, "RELIABILITY_THRESHOLD",
DBL2NUM(chrome_lang_id::NNetLanguageIdentifier::kReliabilityThreshold));
rb_define_const(identifier, "RELIABILITY_HR_BS_THRESHOLD",
DBL2NUM(chrome_lang_id::NNetLanguageIdentifier::kReliabilityHrBsThreshold));

rb_define_singleton_method(unstable, "make", make, 2);
rb_define_method(unstable, "find_language", find_language, 3);
rb_define_method(unstable, "find_top_n_most_freq_langs",
find_top_n_most_freq_langs, 4);

for (int i = 0; ; i++) {
const char *name = chrome_lang_id::TaskContextParams::language_names(i);
if (!name)
break;

rb_ary_push(language_names, ID2SYM(rb_intern(name)));
}
}
Loading

0 comments on commit 59f3ae9

Please sign in to comment.