Skip to content

Commit

Permalink
Less indirection and fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 27, 2019
1 parent 9114ebb commit 45fd9ae
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
22 changes: 11 additions & 11 deletions lib/fasttext/classifier.rb
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,17 @@ def fit(x, y = nil)

# TODO predict multiple in C++ for performance
def predict(text, k: 1, threshold: 0.0)
if text.is_a?(Array)
text.map { |t| predict_one(t, k: k, threshold: threshold) }
else
predict_one(text, k: k, threshold: threshold)
end
multiple = text.is_a?(Array)
text = [text] unless multiple

result =
text.map do |t|
m.predict(prep_text(t), k, threshold).map do |v|
[remove_prefix(v[1]), v[0]]
end.to_h
end

multiple ? result : result.first
end

def test(x, y = nil, k: 1)
Expand Down Expand Up @@ -66,12 +72,6 @@ def labels(include_freq: false)

private

def predict_one(text, k:, threshold:)
m.predict(prep_text(text), k, threshold).map do |v|
[remove_prefix(v[1]), v[0]]
end.to_h
end

def input_path(x, y)
if x.is_a?(String)
raise ArgumentError, "Cannot pass y with file" if y
Expand Down
2 changes: 1 addition & 1 deletion test/classifier_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_works
assert model.sentence_vector("first document")

assert model.predict("First document")
asert model.predict(["First document", "Second document"], k: 3)
assert model.predict(["First document", "Second document"], k: 3)

# TODO fix flaky test
# pred = model.predict("First document").first
Expand Down

0 comments on commit 45fd9ae

Please sign in to comment.