Skip to content
This repository has been archived by the owner on May 8, 2023. It is now read-only.

Commit

Permalink
Added support for Numo
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 12, 2019
1 parent de4fa88 commit 450fd52
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 5 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ before_install:
- ./test/ci/install_xlearn.sh
- export LD_LIBRARY_PATH=$HOME/xlearn/$XLEARN_VERSION/build/lib:$LD_LIBRARY_PATH
cache:
bundler: true
directories:
- $HOME/xlearn
script: bundle exec rake test
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- Added `cv` method
- Added `save_txt` method
- Added support for Numo

## 0.1.0

Expand Down
37 changes: 33 additions & 4 deletions lib/xlearn/dmatrix.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,30 @@ class DMatrix
def initialize(data, label: nil)
@handle = ::FFI::MemoryPointer.new(:pointer)

nrow = data.count
ncol = data.first.count
if matrix?(data)
nrow = data.row_count
ncol = data.column_count
flat_data = data.to_a.flatten
elsif daru?(data)
nrow, ncol = data.shape
flat_data = data.map_rows(&:to_a).flatten
elsif narray?(data)
nrow, ncol = data.shape
# TODO convert to SFloat and pass pointer
# for better performance
flat_data = data.flatten.to_a
else
nrow = data.count
ncol = data.first.count
flat_data = data.flatten
end

c_data = ::FFI::MemoryPointer.new(:float, nrow * ncol)
c_data.put_array_of_float(0, data.flatten)
c_data.put_array_of_float(0, flat_data)

if label
c_label = ::FFI::MemoryPointer.new(:float, nrow)
c_label.put_array_of_float(0, label)
c_label.put_array_of_float(0, label.to_a)
end

# TODO support this
Expand All @@ -31,5 +46,19 @@ def self.finalize(pointer)
# must use proc instead of stabby lambda
proc { FFI.XlearnDataFree(pointer) }
end

private

def matrix?(data)
defined?(Matrix) && data.is_a?(Matrix)
end

def daru?(data)
defined?(Daru::DataFrame) && data.is_a?(Daru::DataFrame)
end

def narray?(data)
defined?(Numo::NArray) && data.is_a?(Numo::NArray)
end
end
end
1 change: 1 addition & 0 deletions test/test_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Bundler.require(:default)
require "minitest/autorun"
require "minitest/pride"
require "numo/narray"

class Minitest::Test
def assert_elements_in_delta(expected, actual, delta = 0.001)
Expand Down
11 changes: 10 additions & 1 deletion test/xlearn_test.rb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
require_relative "test_helper"

class XLearnTest < Minitest::Test
def test_works
def test_linear
x = [[1, 2], [3, 4], [5, 6], [7, 8]]
y = [1, 2, 3, 4]

Expand All @@ -28,4 +28,13 @@ def test_files
model.fit(path, eval_set: path)
model.predict(path, out_path: "/tmp/output.txt")
end

def test_numo
x = Numo::DFloat.cast([[1, 2], [3, 4], [5, 6], [7, 8]])
y = Numo::DFloat.cast([1, 2, 3, 4])

model = XLearn::Linear.new(task: "reg")
model.fit(x, y, eval_set: [x, y])
assert_elements_in_delta y.to_a, model.predict(x), 0.1
end
end
1 change: 1 addition & 0 deletions xlearn.gemspec
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ Gem::Specification.new do |spec|
spec.add_development_dependency "bundler"
spec.add_development_dependency "rake"
spec.add_development_dependency "minitest", ">= 5"
spec.add_development_dependency "numo-narray"
end

0 comments on commit 450fd52

Please sign in to comment.