Skip to content

Commit

Permalink
fix: Add error message when passing unsigned integers to tovw (#4610)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Jun 8, 2023
1 parent 63090d6 commit 4a43d09
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/tests/test_sklearn_vw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
VWRegressor,
tovw,
VWMultiClassifier,
VWRegressor,
)


Expand Down Expand Up @@ -43,6 +44,14 @@ def test_tovw():
assert tovw(x=csr_matrix(x), y=y, sample_weight=w, convert_labels=True) == expected


def test_tovw_raises_for_uint():
X = pd.DataFrame({"a": [1]}, dtype="uint32")
y = pd.Series(np.zeros(1))

with pytest.raises(TypeError):
VWRegressor().fit(X, y)


class BaseVWTest:
estimator = None

Expand Down
4 changes: 4 additions & 0 deletions python/vowpalwabbit/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,10 @@ def tovw(x, y=None, sample_weight=None, convert_labels=False):
for row in rows:
for col in cols:
x[row, col] = INVALID_CHARS.sub(".", x[row, col])
elif x.dtype.kind == "u":
raise TypeError(
"tovw does not support unsigned integers. Please convert to signed integers. See issue: https://github.com/VowpalWabbit/vowpal_wabbit/issues/4609"
)

# convert input to svmlight format
s = io.BytesIO()
Expand Down

0 comments on commit 4a43d09

Please sign in to comment.