Add deletion of model files, validation predictions and test predictions (#807)

* Add deletion of model files

* Add deletion of test and validation files

* remove test that verifies the deletion of validation files

* PEP8

* Reverses logic: loop through the directory instead of the list of candidates

* PEP8

* Correct error message

* data structure changes

* Improve readability

* rewrite AbstractEvaluator.file_output() without changing its functionality

* implement locks on AbstractEvaluator.file_output()

* bug fix

* simplify lock naming

* Adapt unittest

* fix AbstractEvaluatorTest

* Add some nosetest byproducts to .gitignore

* Fix FunctionsTest (test_train_evaluator.py)

* Fix a couple of unit tests from TestTrainEvaluator

* PEP8

* delete unnecessary line of code
gui-miotto committed Apr 10, 2020
1 parent 3bff740 commit dc5537e
Showing 8 changed files with 291 additions and 124 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -27,6 +27,8 @@ __pycache__
pip-log.txt

# Unit test / coverage reports
.noseids
nosetests.xml
htmlcov
.coverage
.tox
104 changes: 83 additions & 21 deletions autosklearn/ensemble_builder.py
@@ -10,6 +10,7 @@

import numpy as np
import pynisher
import lockfile
from sklearn.utils.validation import check_random_state

from autosklearn.util.backend import Backend
@@ -145,7 +146,11 @@ def __init__(
'.auto-sklearn',
'predictions_test',
)

self.dir_models = os.path.join(
self.backend.temporary_directory,
'.auto-sklearn',
'models',
)
logger_name = 'EnsembleBuilder(%d):%s' % (self.seed, self.dataset_name)
self.logger = get_logger(logger_name)
if max_keep_best == 1:
@@ -308,15 +313,15 @@ def read_ensemble_preds(self):
'predictions_ensemble_*_*_*.npy',
)

y_ens_files = glob.glob(pred_path)
self.y_ens_files = glob.glob(pred_path)
# no validation predictions so far -- no files
if len(y_ens_files) == 0:
if len(self.y_ens_files) == 0:
self.logger.debug("Found no prediction files on ensemble data set:"
" %s" % pred_path)
return False

n_read_files = 0
for y_ens_fn in sorted(y_ens_files):
for y_ens_fn in sorted(self.y_ens_files):

if self.read_at_most and n_read_files >= self.read_at_most:
# limit the number of files that will be read
@@ -340,7 +345,6 @@ def read_ensemble_preds(self):
"seed": _seed,
"num_run": _num_run,
"budget": _budget,
"deleted": False,
Y_ENSEMBLE: None,
Y_VALID: None,
Y_TEST: None,
@@ -408,7 +412,7 @@ def get_n_best_preds(self):
according to score on "ensemble set"
n: self.ensemble_nbest
Side effect: delete predictions of non-winning models
Side effect: delete predictions of non-candidate models
"""

# Sort by score - higher is better!
@@ -489,7 +493,7 @@ def get_n_best_preds(self):
# reduce to keys
sorted_keys = list(map(lambda x: x[0], sorted_keys))

# remove loaded predictions for non-winning models
# remove loaded predictions for non-candidate models
for k in sorted_keys[ensemble_n_best:]:
self.read_preds[k][Y_ENSEMBLE] = None
self.read_preds[k][Y_VALID] = None
@@ -736,25 +740,83 @@ def predict(self, set_: str,
# TODO: ADD saving of predictions on "ensemble data"

def _delete_non_candidate_models(self, candidates):
candidates = [os.path.split(cand)[1] for cand in candidates]
for model_path in self.read_preds.keys():
if self.read_preds[model_path]['deleted']:
# Loop through the files currently in the directory
for pred_path in self.y_ens_files:

# Do not delete candidates
if pred_path in candidates:
continue
match = self.model_fn_re.search(model_path)
_num_run = int(match.group(2))
# Do not remove the dummy prediction!
if _num_run == 1:

match = self.model_fn_re.search(pred_path)
_full_name = match.group(0)
_seed = match.group(1)
_num_run = match.group(2)
_budget = match.group(3)

# Do not delete the dummy prediction
if int(_num_run) == 1:
continue
model_file = os.path.split(model_path)[1]
if model_file not in candidates:

# Besides the prediction, we have to take care of three other files: model,
# validation and test.
model_name = '%s.%s.%s.model' % (_seed, _num_run, _budget)
model_path = os.path.join(self.dir_models, model_name)
pred_valid_name = 'predictions_valid' + _full_name
pred_valid_path = os.path.join(self.dir_valid, pred_valid_name)
pred_test_name = 'predictions_test' + _full_name
pred_test_path = os.path.join(self.dir_test, pred_test_name)

paths = [model_path, pred_path]
if os.path.exists(pred_valid_path):
paths.append(pred_valid_path)
if os.path.exists(pred_test_path):
paths.append(pred_test_path)

# Let's lock all the files "at once" to avoid weird race conditions. Also,
# we either delete all files of a model (model, prediction, validation
# and test), or delete none. This makes it easier to keep track of which
# models have indeed been correctly removed.
locks = [lockfile.LockFile(path) for path in paths]
try:
for lock in locks:
lock.acquire()
except Exception as e:
if isinstance(e, lockfile.AlreadyLocked):
# If the file is already locked, we deal with it later. Not a big deal
self.logger.info(
'Model %s is already locked. Skipping it for now.', model_name)
else:
# Other exceptions, however, should not occur.
# The message below is asserted in test_delete_non_candidate_models()
self.logger.error(
'Failed to lock model %s files due to error %s', model_name, e)
for lock in locks:
if lock.i_am_locking():
lock.release()
continue

# Delete the files only if the model is not a candidate AND its prediction
# is old. We check that the prediction is old to avoid deleting a model
# that has not been considered by self.get_n_best_preds() yet.
original_timestamp = self.read_preds[pred_path]['mtime_ens']
current_timestamp = os.path.getmtime(pred_path)
if current_timestamp == original_timestamp:
# The messages logged here are asserted in
# test_delete_non_candidate_models(). Edit with care.
try:
os.remove(model_path)
self.logger.info("Deleted file of non-candidate model %s", model_path)
self.read_preds[model_path]['deleted'] = True
for path in paths:
os.remove(path)
self.logger.info(
"Deleted files of non-candidate model %s", model_name)
except Exception as e:
self.logger.error(
'Failed to delete non-candidate model %s due to error %s',
model_path, e)
"Failed to delete files of non-candidate model %s due"
" to error %s", model_name, e)

# If we reached this point, all locks were done by this thread. So no need
# to check "lock.i_am_locking()" here.
for lock in locks:
lock.release()

def _read_np_fn(self, fp):
if self.precision == "16":
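The deletion logic above boils down to an all-or-none locking pattern: lock every file belonging to a model, then delete the whole group or delete nothing. A minimal standalone sketch of that pattern follows; the helper name and the acquire(timeout=0) call are illustrative assumptions (the diff itself calls acquire() with no arguments):

import os

import lockfile


def delete_files_atomically(paths, logger):
    # Lock every file in the group first; if any of them is already locked,
    # release whatever we hold and delete nothing.
    locks = [lockfile.LockFile(path) for path in paths]
    try:
        for lock in locks:
            lock.acquire(timeout=0)  # raises AlreadyLocked instead of blocking
    except lockfile.AlreadyLocked:
        logger.info('Some of the files are already locked. Skipping them for now.')
        for lock in locks:
            if lock.i_am_locking():
                lock.release()
        return False
    try:
        # All locks are held by this thread, so the whole group can be removed.
        for path in paths:
            os.remove(path)
    finally:
        for lock in locks:
            lock.release()
    return True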
128 changes: 96 additions & 32 deletions autosklearn/evaluation/abstract_evaluator.py
@@ -1,7 +1,9 @@
import os
import time
import warnings
from collections import namedtuple

import lockfile
import numpy as np
from sklearn.dummy import DummyClassifier, DummyRegressor
from smac.tae.execute_ta_run import StatusType
@@ -27,6 +29,7 @@
'AbstractEvaluator'
]

WriteTask = namedtuple('WriteTask', ['lock', 'writer', 'args'])
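# A WriteTask bundles the lock that guards one output file, the backend
# function that writes that file, and the arguments for the call;
# file_output() below collects these tasks and only starts writing once
# every lock in the batch has been acquired.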

class MyDummyClassifier(DummyClassifier):
def __init__(self, configuration, random_state, init_params=None):
@@ -142,7 +145,11 @@ def __init__(self, backend, queue, metric,

self.output_y_hat_optimization = output_y_hat_optimization
self.all_scoring_functions = all_scoring_functions
self.disable_file_output = disable_file_output

if isinstance(disable_file_output, (bool, list)):
self.disable_file_output = disable_file_output
else:
raise ValueError('disable_file_output should be either a bool or a list')

if self.task_type in REGRESSION_TASKS:
if not isinstance(self.configuration, Configuration):
@@ -328,14 +335,14 @@ def file_output(
Y_valid_pred,
Y_test_pred
):
# TODO refactor this function to only output and calculate the loss
# for one specific data set - optimization, validation or test!
seed = self.seed

# Abort if self.Y_optimization is None
# self.Y_optimization can be None if we use partial-cv; in that
# case no output should be saved.
if self.Y_optimization is not None and \
self.Y_optimization.shape[0] != Y_optimization_pred.shape[0]:
if self.Y_optimization is None:
return None, {}

# Abort in case of shape misalignment
if self.Y_optimization.shape[0] != Y_optimization_pred.shape[0]:
return (
1.0,
{
@@ -346,6 +353,7 @@ def file_output(
},
)

# Abort if predictions contain NaNs
for y, s in [
# Y_train_pred deleted here. Fix unittest accordingly.
[Y_optimization_pred, 'optimization'],
@@ -361,43 +369,99 @@ def file_output(
},
)

num_run = str(self.num_run)
# Abort if we don't want to output anything.
# Since disable_file_output can also be a list, we have to explicitly
# compare it with True.
if self.disable_file_output is True:
return None, {}

if (
self.disable_file_output != True and (
not isinstance(self.disable_file_output, list)
or 'model' not in self.disable_file_output
)
):
if os.path.exists(self.backend.get_model_dir()):
self.backend.save_model(self.model, self.num_run, seed, self.budget)
# Notice that disable_file_output==False and disable_file_output==[]
# mean the same thing here.
if self.disable_file_output is False:
self.disable_file_output = []
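# Illustration of the resulting semantics (example values, not part of
# this change):
#   disable_file_output=True         -> nothing is written at all
#   disable_file_output=False or []  -> all files below are written
#   disable_file_output=['model']    -> everything except the model file is written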

if (
self.disable_file_output != True and (
not isinstance(self.disable_file_output, list)
or 'y_optimization' not in self.disable_file_output
)
):
# This file can be written independently of the others down below
if ('y_optimization' not in self.disable_file_output):
if self.output_y_hat_optimization:
try:
os.makedirs(self.backend.output_directory)
except OSError:
pass
self.backend.save_targets_ensemble(self.Y_optimization)

self.backend.save_predictions_as_npy(
Y_optimization_pred, 'ensemble', seed, num_run, self.budget,
)
# The other four files have to be written together, meaning we start
# writing them just after acquiring the locks for all of them.
# But first we have to check which files have to be written.
write_tasks = []

# File 1 of 4: model
if ('model' not in self.disable_file_output):
if os.path.exists(self.backend.get_model_dir()):
file_path = self.backend.get_model_path(
self.seed, self.num_run, self.budget)
write_tasks.append(
WriteTask(
lock=lockfile.LockFile(file_path),
writer=self.backend.save_model,
args=(self.model, file_path)
))

# File 2 of 4: predictions
if ('y_optimization' not in self.disable_file_output):
file_path = self.backend.get_prediction_output_path(
'ensemble', self.seed, self.num_run, self.budget)
write_tasks.append(
WriteTask(
lock=lockfile.LockFile(file_path),
writer=self.backend.save_predictions_as_npy,
args=(Y_optimization_pred, file_path)
))

# File 3 of 4: validation predictions
if Y_valid_pred is not None:
if self.disable_file_output != True:
self.backend.save_predictions_as_npy(Y_valid_pred, 'valid',
seed, num_run, self.budget)

file_path = self.backend.get_prediction_output_path(
'valid', self.seed, self.num_run, self.budget)
write_tasks.append(
WriteTask(
lock=lockfile.LockFile(file_path),
writer=self.backend.save_predictions_as_npy,
args=(Y_valid_pred, file_path)
))

# File 4 of 4: test predictions
if Y_test_pred is not None:
if self.disable_file_output != True:
self.backend.save_predictions_as_npy(Y_test_pred, 'test',
seed, num_run, self.budget)
file_path = self.backend.get_prediction_output_path(
'test', self.seed, self.num_run, self.budget)
write_tasks.append(
WriteTask(
lock=lockfile.LockFile(file_path),
writer=self.backend.save_predictions_as_npy,
args=(Y_test_pred, file_path)
))

# We then acquire the locks one by one in a stubborn fashion, i.e. if a file is
# already locked, we keep probing it until it is unlocked. This will NOT create a
# race condition with _delete_non_candidate_models(), since that function
# doesn't acquire locks in this stubborn way: it releases all of its
# locks and aborts the acquisition process as soon as it finds a locked file.
for wt in write_tasks:
while True:
try:
wt.lock.acquire()
break
except lockfile.AlreadyLocked:
time.sleep(.1)
continue
except Exception as e:
raise RuntimeError('Failed to lock %s due to %s' % (wt.lock, e))

# At this point we are good to write the files
for wt in write_tasks:
wt.writer(*wt.args)

# And finally release the locks
for wt in write_tasks:
wt.lock.release()

return None, {}

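The write side of file_output() is the complementary half of the protocol: stubbornly acquire every lock, write all files, release all locks. A condensed sketch under the same assumptions (illustrative helper name, and acquire(timeout=0) instead of the diff's bare acquire() so the AlreadyLocked path is explicit):

import time
from collections import namedtuple

import lockfile

WriteTask = namedtuple('WriteTask', ['lock', 'writer', 'args'])


def run_write_tasks(write_tasks):
    # Phase 1: acquire every lock, probing a locked file until it is free.
    # Unlike _delete_non_candidate_models(), this never gives up, which is
    # safe because the delete side backs off as soon as it sees a lock.
    for wt in write_tasks:
        while True:
            try:
                wt.lock.acquire(timeout=0)
                break
            except lockfile.AlreadyLocked:
                time.sleep(.1)
    # Phase 2: with all locks held, write the files.
    for wt in write_tasks:
        wt.writer(*wt.args)
    # Phase 3: release all locks.
    for wt in write_tasks:
        wt.lock.release()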
