Skip to content
Permalink
Browse files

allow non-strings as parameters for try_convert.

  • Loading branch information...
amueller committed Jan 12, 2015
1 parent daa6dcf commit 0f65fe0c03755676effe7ee0f10b570a8ded7f21
Showing with 9 additions and 7 deletions.
  1. +5 −1 vistrails/packages/sklearn/init.py
  2. +4 −6 vistrails/packages/sklearn/tests.py
@@ -12,6 +12,9 @@


def try_convert(input_string):
if not isinstance(input_string, basestring):
# already converted
return input_string
if input_string.isdigit():
return int(input_string)
try:
@@ -145,7 +148,8 @@ class TrainTestSplit(Module):

def compute(self):
X_train, X_test, y_train, y_test = \
train_test_split(self.get_input("data"), self.get_input("target"))
train_test_split(self.get_input("data"), self.get_input("target"),
test_size=try_convert(self.get_input("test_size")))
self.set_output("training_data", X_train)
self.set_output("training_target", y_train)
self.set_output("test_data", X_test)
@@ -33,10 +33,8 @@ def test_train_test_split(self):
# check that we can split the iris dataset
with intercept_results(TrainTestSplit, 'training_data', TrainTestSplit,
'training_target', TrainTestSplit, 'test_data',
TrainTestSplit, 'test_data') as (X_train,
y_train,
X_test,
y_test):
TrainTestSplit, 'test_target') as (
X_train, y_train, X_test, y_test):
self.assertFalse(execute(
[
('datasets|Iris', identifier, []),
@@ -50,8 +48,8 @@ def test_train_test_split(self):
))
X_train = np.vstack(X_train)
X_test = np.vstack(X_test)
y_train = np.vstack(y_train)
y_test = np.vstack(y_test)
y_train = np.hstack(y_train)
y_test = np.hstack(y_test)
self.assertEqual(X_train.shape, (100, 4))
self.assertEqual(X_test.shape, (50, 4))
self.assertEqual(y_train.shape, (100,))

0 comments on commit 0f65fe0

Please sign in to comment.
You can’t perform that action at this time.