Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restored support for sparse matrices (which was deprecated) #310

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions python-package/xlearn/_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import tempfile
import warnings
import numpy as np
import scipy

from .xlearn import create_linear, create_fm, create_ffm
from .data import DMatrix
Expand Down Expand Up @@ -209,6 +210,8 @@ def fit(self, X, y=None, fields=None,
self._XLearnModel = create_ffm()
else:
raise Exception('model_type must be fm, ffm or lr')

temp_train_file = tempfile.NamedTemporaryFile(delete=True)

if y is None:
assert isinstance(X, str), 'X must be a string specifying training file location' \
Expand All @@ -223,7 +226,7 @@ def fit(self, X, y=None, fields=None,
self.fields = fields

# convert data into libsvm/libffm format for training
train_set = DMatrix(X, y, self.fields)
train_set = self._apply_correct_transformation(X, y, temp_train_file.name, fields=self.fields)
self._XLearnModel.setTrain(train_set)

# TODO: find out what task need to set sigmoid
Expand Down Expand Up @@ -252,14 +255,14 @@ def fit(self, X, y=None, fields=None,
else:
if not (isinstance(eval_set, list) and len(eval_set) == 2):
raise Exception('eval_set must be a 2-element list')

temp_val_file = tempfile.NamedTemporaryFile(delete=True)
# extract validation data
X_val, y_val = check_X_y(eval_set[0], eval_set[1],
accept_sparse=['csr'],
y_numeric=True,
multi_output=False)

validate_set = DMatrix(X_val, y_val, self.fields)
validate_set = self._apply_correct_transformation(X_val, y_val, temp_val_file.name, fields=self.fields)
self._XLearnModel.setValidate(validate_set)

# set up files for storing weights
Expand All @@ -270,6 +273,9 @@ def fit(self, X, y=None, fields=None,

# acquire weights
self._parse_weight(self._temp_weight_file.name)

# remove temporary files for training
self._remove_temp_file(temp_train_file)

def predict(self, X):
""" Generate prediction using feature matrix X
Expand All @@ -278,17 +284,21 @@ def predict(self, X):
Feature matrix
:return: prediction
"""


temp_test_file = tempfile.NamedTemporaryFile(delete=True)

if isinstance(X, str):
self._XLearnModel.setTest(X)
else:
X = check_array(X, accept_sparse=['csr'])
test_set = DMatrix(X, None, self.fields)
test_set = self._apply_correct_transformation(X, None, temp_test_file.name, fields = self.fields)
self._XLearnModel.setTest(test_set)

# generate output
pred = self.get_model().predict(self._temp_model_file.name)

# remove temporary test data
self._remove_temp_file(temp_test_file)

return pred

def feature_importance_(self):
Expand Down Expand Up @@ -318,6 +328,15 @@ def _convert_data(self, X, y, filepath, fields=None):
except:
raise Exception('Failed to convert feature matrix X and label y to xlearn data format')

def _apply_correct_transformation(self, X, y, filepath, fields=None):
    """Pick the data representation that matches the type of X.

    Sparse input is written to ``filepath`` through ``self._convert_data``
    and the path itself is returned; any other input is wrapped in a
    ``DMatrix`` and that object is returned.

    :param X: feature matrix, either a SciPy sparse matrix or any
              non-sparse input accepted by ``DMatrix``
    :param y: labels, or None when there are none (e.g. at prediction time)
    :param filepath: path of the file that receives the converted sparse data
    :param fields: optional field information, forwarded unchanged
    :return: ``filepath`` when X is sparse, otherwise a ``DMatrix``
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.sparse`` is loaded on older SciPy releases.
    from scipy.sparse import issparse

    if issparse(X):
        self._convert_data(X, y, filepath, fields=fields)
        return filepath

    return DMatrix(X, y, fields)

def _parse_weight(self, file_name):
# estimate number of features from txt file
num_lines = sum(1 for line in open(file_name))
Expand Down