Skip to content

Commit

Permalink
support clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
fukatani committed Nov 13, 2017
1 parent 406fc1b commit 0561611
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
33 changes: 33 additions & 0 deletions rgf/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,22 @@ def _cleanup():
os.remove(fn)


def _cleanup_partial(uuid):
if uuid not in _UUIDS:
return
model_glob = os.path.join(_TEMP_PATH, uuid + "*")
for fn in glob(model_glob):
os.remove(fn)
_UUIDS.remove(uuid)


def _get_temp_path():
"""
For test
"""
return _TEMP_PATH


def _sigmoid(x):
return 1.0 / (1.0 + np.exp(-x))

Expand Down Expand Up @@ -683,6 +699,17 @@ def predict(self, X):
y = np.argmax(y, axis=1)
return np.asarray(list(self._classes_map.values()))[np.searchsorted(list(self._classes_map.keys()), y)]

def cleanup(self):
"""
Clean tempfile used by this model.
"""
if self._estimators is not None:
for est in self._estimators:
_cleanup_partial(est._file_prefix)

# No more able to predict without refitting.
self._fitted = False


class _RGFBinaryClassifier(BaseEstimator, ClassifierMixin):
"""
Expand Down Expand Up @@ -1208,3 +1235,9 @@ def __setstate__(self, state):
with open(self._latest_model_loc, 'wb') as fw:
fw.write(self.__dict__["model"])
del self.__dict__["model"]

def cleanup(self):
"""
Clean tempfile used by this model.
"""
_cleanup_partial(self._file_prefix)
21 changes: 20 additions & 1 deletion rgf/test/test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import glob
import os
import pickle
import unittest

Expand All @@ -11,7 +13,7 @@
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.validation import check_random_state

from rgf.sklearn import RGFClassifier, RGFRegressor, _cleanup
from rgf.sklearn import RGFClassifier, RGFRegressor, _cleanup, _get_temp_path


class TestRGFClassfier(unittest.TestCase):
Expand Down Expand Up @@ -241,6 +243,15 @@ def test_joblib_pickle(self):

np.testing.assert_allclose(y_pred1, y_pred2)

def test_cleanup(self):
clf = RGFClassifier()
clf.fit(self.X_train, self.y_train)
clf.cleanup()

for est in clf.estimators_:
glob_file = os.path.join(_get_temp_path(), est._file_prefix + "*")
self.assertFalse(glob.glob(glob_file))


class TestRGFRegressor(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -408,6 +419,14 @@ def test_joblib_pickle(self):

np.testing.assert_allclose(y_pred1, y_pred2)

def test_cleanup(self):
reg = RGFRegressor()
reg.fit(self.X_train, self.y_train)
reg.cleanup()

glob_file = os.path.join(_get_temp_path(), reg._file_prefix + "*")
self.assertFalse(glob.glob(glob_file))


if __name__ == '__main__':
unittest.main()

0 comments on commit 0561611

Please sign in to comment.