diff --git a/python-package/xlearn/_sklearn.py b/python-package/xlearn/_sklearn.py index 8efafdd7..6acafc81 100644 --- a/python-package/xlearn/_sklearn.py +++ b/python-package/xlearn/_sklearn.py @@ -18,6 +18,7 @@ import tempfile import warnings import numpy as np +import scipy from .xlearn import create_linear, create_fm, create_ffm from .data import DMatrix @@ -209,6 +210,8 @@ def fit(self, X, y=None, fields=None, self._XLearnModel = create_ffm() else: raise Exception('model_type must be fm, ffm or lr') + + temp_train_file = tempfile.NamedTemporaryFile(delete=True) if y is None: assert isinstance(X, str), 'X must be a string specifying training file location' \ @@ -223,7 +226,7 @@ def fit(self, X, y=None, fields=None, self.fields = fields # convert data into libsvm/libffm format for training - train_set = DMatrix(X, y, self.fields) + train_set = self._apply_correct_transformation(X, y, temp_train_file.name, fields=self.fields) self._XLearnModel.setTrain(train_set) # TODO: find out what task need to set sigmoid @@ -252,14 +255,14 @@ def fit(self, X, y=None, fields=None, else: if not (isinstance(eval_set, list) and len(eval_set) == 2): raise Exception('eval_set must be a 2-element list') - + temp_val_file = tempfile.NamedTemporaryFile(delete=True) # extract validation data X_val, y_val = check_X_y(eval_set[0], eval_set[1], accept_sparse=['csr'], y_numeric=True, multi_output=False) - - validate_set = DMatrix(X_val, y_val, self.fields) + + validate_set = self._apply_correct_transformation(X_val, y_val, temp_val_file.name, fields=self.fields) self._XLearnModel.setValidate(validate_set) # set up files for storing weights @@ -270,6 +273,9 @@ def fit(self, X, y=None, fields=None, # acquire weights self._parse_weight(self._temp_weight_file.name) + + # remove temporary files for training + self._remove_temp_file(temp_train_file) def predict(self, X): """ Generate prediction using feature matrix X @@ -278,17 +284,21 @@ def predict(self, X): Feature matrix :return: prediction """ - + + temp_test_file = tempfile.NamedTemporaryFile(delete=True) + if isinstance(X, str): self._XLearnModel.setTest(X) else: X = check_array(X, accept_sparse=['csr']) - test_set = DMatrix(X, None, self.fields) + test_set = self._apply_correct_transformation(X, None, temp_test_file.name, fields = self.fields) self._XLearnModel.setTest(test_set) # generate output pred = self.get_model().predict(self._temp_model_file.name) - + # remove temporary test data + self._remove_temp_file(temp_test_file) + return pred def feature_importance_(self): @@ -318,6 +328,15 @@ def _convert_data(self, X, y, filepath, fields=None): except: raise Exception('Failed to convert feature matrix X and label y to xlearn data format') + def _apply_correct_transformation(self, X, y, filepath, fields=None): + #check type of object and apply correct transformation + if (scipy.sparse.issparse(X)): + self._convert_data(X, y, filepath, fields = fields) + return filepath + else: + coverted_data = DMatrix(X, y, fields) + return coverted_data + def _parse_weight(self, file_name): # estimate number of features from txt file num_lines = sum(1 for line in open(file_name))