Support checkpointing and reloading models #32
Conversation
import os
import re
from typing import Tuple, Optional
import xgboost
For consistency, import xgboost as xgb (and update the corresponding references).
Done.
"""
# TODO: replace with CSDK constants in sagemaker_containers._env
train_config = json.load(open(os.getenv(sm_env_constants.SM_INPUT_TRAINING_CONFIG_FILE), "r"))
data_config = json.load(open(os.getenv(sm_env_constants.SM_INPUT_DATA_CONFIG_FILE), "r"))

checkpoint_config_file = os.getenv(sm_env_constants.SM_CHECKPOINT_CONFIG_FILE)
if os.path.exists(checkpoint_config_file):
    checkpoint_config = json.load(open(checkpoint_config_file, "r"))
From my understanding, it's not strictly necessary to explicitly close an opened file, but leaving files unclosed is generally considered bad practice. I notice that the same is done in the lines above (44-45). Thoughts?
I agree. I should have changed the previous lines instead of copying them. Changed.
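For reference, the context-manager pattern being suggested here might look like the following sketch (the helper name is illustrative, not from the PR):

```python
import json

def load_json_config(path):
    # The "with" block guarantees the file is closed, even if json.load raises.
    with open(path, "r") as f:
        return json.load(f)
```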
file_path = os.path.join(self.test_dir, "xgboost-checkpoint.000000000042")
self.assertTrue(os.path.isfile(file_path))
self.assertTrue(len(os.listdir(self.test_dir)), 1)
I'm assuming no file markers are generated in these unit tests, but do you check for/test them in the integration tests?
I have some basic unit tests for marker files in test_SaveCheckpoint_uploading. Integration tests don't really check for marker files.
Gotcha, it might be helpful for documentation to have some comments explaining what that test is doing, since it's not immediately clear to me. I think the other tests are self-explanatory.
On that note, it would be good to include the information you've provided in this PR as comments in the code so it's clear for another developer in the future what you've done.
def __init__(
    self, checkpoint_dir: str, max_to_keep: int = 5, start_iteration: int = 0,
    num_round: Optional[int] = None
) -> None:
Ending parens should be on the same line as the last element; see PEP 8: https://www.python.org/dev/peps/pep-0008/#indentation
logging.basicConfig(level=logging.INFO)


def _zero_pad(iteration: int) -> str:
Method names should ideally start with an action, e.g. '_add_zero_pad', unless it is a property.
return iter_num


# modified from https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/callback.py
Making sure: is the main change here the addition of the method argument start_iteration? Can you explain in comments why this was required?
Can we rename this method to print_checkpointed_evaluation?
return self.callback(env)

def fmt_path(self, i: int) -> str:
What is i? Why are we zero_padding the file name?
Also, if we are keeping this method, can you just rename it to format_path? You don't gain anything by removing the couple of letters, and you lose readability later.
Will change the variable and function names to be more descriptive.
On zero padding, I used the same format that tensorboard uses to save tfevent files (I just happened to be working on something else that uses tensorboard at the time). I think the rationale is that zero padding makes the file names string-sortable.
If that is the case, can you describe the rationale for how much zero padding you want to do and document it? Sortability is nice, but a limited amount of zero padding affects scalability. In general I would opt for scalability over sortability; also, why would users have to sort the checkpoints?
By scalability, do you mean we might run out of zeros, like .999999999999 -> .000000000000? Choosing 12 digits was somewhat arbitrary, but I don't expect anyone to train for a trillion rounds.
Sortability is useful in train_utils.load_checkpoint(), where I sort the checkpoints and try to load the latest one, but that can easily be changed if we remove zero-padding. I think you have a point; from the user's perspective, zero-padding is unnecessary.
It's a valid point, so I could get rid of zero-padding and modify the sorting method in load_checkpoint(). Do you think that's better?
I agree about the trillion rounds, but there's no need to add limits that aren't necessary.
I think we should remove the zero-padding if there is no customer use case for it.
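If zero-padding is dropped, sorting on the numeric suffix recovers the latest checkpoint just as well; a sketch (the file naming follows the discussion above, the helper name is illustrative):

```python
def latest_checkpoint(filenames):
    # A plain lexicographic sort would put "xgboost-checkpoint.9" after
    # "xgboost-checkpoint.10"; keying on the numeric suffix avoids that.
    return max(filenames, key=lambda name: int(name.rsplit(".", 1)[1]))
```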
etc. However, this could result in a large number of files to save locally
(and on S3 if used in SageMaker). To save disk space and reduce the amount
of data required to be downloaded on resumption, we retain only the N
(default 5, specified by 'max_to_keep') most recent checkpoints.
Based on my understanding, if we have checkpoints queued up, the upload only happens on the most recent checkpoint. So we retain the n - 1 most recent uploaded checkpoints plus the queued checkpoints that have not been uploaded yet.
If my understanding is incorrect, please let me know.
    self, checkpoint_dir: str, max_to_keep: int = 5, start_iteration: int = 0,
    num_round: Optional[int] = None
) -> None:
    """
The formatting of this file is different from the other files in the repo. Is this an artifact of the copy/paste of the print_evaluation method? I would prefer that copied methods be put in another file and that this file follow the repo's formatting patterns for cohesiveness.
class SaveCheckpoint:
    """Create a callback that saves checkpoints to disk.
Minor: Doc reads like method documentation
of data required to be downloaded on resumption, we retain only the N
(default 5, specified by 'max_to_keep') most recent checkpoints.

To delete stale checkpoints, we use a producer-consumer pattern: we start a
This doc is nice, but we should also have comments near the code that actually executes it.
path = self.fmt_path(i)
if (skip_locked_files
        and os.path.isfile(path + FILE_LOCK_SUFFIX)
Can you move these two .isfile() checks to a well-named method for readability? I would also re-document the behavior here; whenever we have asynchronous processes, I'd like to make sure it's really clear what's happening.
self.delete_queue.put(self.SENTINEL)
self.thread.join()

def callback(self, env: CallbackEnv) -> None:
Can you describe what is being done here? Between the file suffixes and trying to keep up with the file names, it's a bit confusing.
return callback


class SaveCheckpoint:
Overall this class seems either much more complex than needed or not complex enough. If we are deleting files based on 'flag' files, why do we need to maintain a queue to remove files? Wouldn't it be easier to just remove all files that we recognize as done?
On the other side, I'm not as familiar with spot instances; what exactly is the behavior when a spot instance is taken away? Does the kernel recognize it as an interrupt? Is this process supposed to block on uploading the files, or do we just take it as a best attempt and take whatever managed to be saved in S3?
We are constrained by how spot instances work and the way EASE uploads. When a spot instance is about to be taken away, we get an interruption notice via CloudWatch, and there is a two-minute window before the instance goes away. By the time the kernel gets a signal, it's too late; that happens after the two-minute warning has expired and the kernel is already shutting down. EASE does not monitor this signal or CloudWatch; it just keeps uploading files until the instance shuts down, relying on the fact that S3 uploads are all-or-nothing.
So there is no guarantee that a callback will be executed inside that window before the instance is interrupted. We have to make our best attempt along the way without relying on the callback being called; hence the asynchronous deletion. I'm also seeing more stable performance with the async approach vs. deleting in the callback (simply removing files marked as done was actually my initial approach; occasionally I would see a very slow training run, probably because the callback was held up by deletion), although I will have to spend some time on profiling and tests to back up that statement with a statistically significant result.
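A minimal sketch of the producer-consumer deletion being described (the names SENTINEL and start_delete_thread are illustrative, not the PR's exact implementation):

```python
import os
import queue
import threading

SENTINEL = object()

def start_delete_thread(delete_queue):
    # Consumer: remove stale checkpoint files off the training thread,
    # so the training callback never blocks on filesystem deletes.
    def _worker():
        while True:
            path = delete_queue.get()
            if path is SENTINEL:
                break
            try:
                os.remove(path)
            except OSError:
                pass  # file already gone or still locked; best effort
    thread = threading.Thread(target=_worker, daemon=True)
    thread.start()
    return thread
```

The training callback acts as the producer: it enqueues stale checkpoint paths and, at shutdown, the sentinel.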
import logging
Can we rename this file to checkpointing.py? Let's be specific about the use case of the code in the file.
OK, how about just checkpoint.py? Can I pull it out one level up and put it at src/sagemaker_xgboost-container/checkpoint.py? I'm working on testing a script for script mode, and this was intended to work in both script mode and algorithm mode.
Oh, I see serving.py, training.py, so checkpointing.py. Let's do checkpointing.py.
os.rename(tf.name, self.fmt_path(i))

target_file = self.fmt_path(i - self.max_to_keep)
Can you explain the intuition for this index, please?
target_file is what I want to put on the queue. For example, if the current iteration is 5 (i.e., i == 5; I will use better variable names in the update), I want to put xgboost-checkpoint.0 on the queue, because I only want to keep xgboost-checkpoint.1 to .5. So when i == 5, i - self.max_to_keep == 0, and self.fmt_path(i - self.max_to_keep) returns /path/to/xgboost-checkpoint.0. Will update with better variable names, helper functions, and code comments.
Looking better.
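The arithmetic from this thread, as a tiny sketch (the helper name is hypothetical):

```python
def stale_checkpoint_iteration(current_iteration, max_to_keep):
    # With max_to_keep == 5 and current_iteration == 5, iteration 0's
    # checkpoint becomes stale, leaving checkpoints 1 through 5 on disk.
    return current_iteration - max_to_keep
```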
return callback


def load_checkpoint(checkpoint_dir, max_try=5):
What transient errors would cause an attempt to fail? Waiting for SageMaker to copy?
Training won't start until all the checkpoints have been downloaded, so we are safe there. The retries are for checking whether we have corrupted files (e.g., partial uploads from a previous run).
Why would we retry for corrupted files? Those files will never successfully load, right? Wouldn't it be better to fail fast?
# a delete signal to the agent so that the upload can be canceled and removed
# from S3.
self.delete_queue.put(self.SENTINEL)
_delete_once(skip_locked_files=False)
In the normal case, the second _delete_once doesn't do anything because the first call should only return once everything in the queue is removed, right? Otherwise, won't it continue to loop while getting from the queue?
This second _delete_once call is confusing to me.
Yes, in the normal case, that function exits immediately and doesn't do anything.
I think my attempt to reuse the same function for two different things hurt readability here. I'm going to refactor and break _delete_once() into two functions even if they share some code. I hope that improves readability.
To add: the reason I'm going through everything twice is to avoid potential file corruption from deleting a file that is still being uploaded; the EASE team was concerned about this aspect when I talked to them. Again, the second call won't do anything in normal cases. Added code comments to address this point, and hopefully the refactor makes the code easier to read.
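The marker-file check under discussion, mirroring the snippet below (the suffix values are taken from the PR description; the function name here is illustrative):

```python
import os

FILE_LOCK_SUFFIX = ".sagemaker-uploading"
FILE_SAFE_SUFFIX = ".sagemaker-uploaded"

def upload_in_progress(path):
    # SageMaker creates <file>.sagemaker-uploading while a file is being
    # uploaded and <file>.sagemaker-uploaded once the upload has finished,
    # so a file is only safe to delete when it is not mid-upload.
    uploading = os.path.isfile(path + FILE_LOCK_SUFFIX)
    uploaded = os.path.isfile(path + FILE_SAFE_SUFFIX)
    return uploading and not uploaded
```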
uploaded = os.path.isfile(path + FILE_SAFE_SUFFIX)
return uploading and not uploaded

def _delete_once(skip_locked_files=True):
Why not just call this _delete_uploaded_files?
return checkpoint_path

def start(self):
    """Starts a background thread that deletes old checkpoints
PEP 257 dictates that docstrings for methods read as a command, not a description:
"""Start background thread that ...
See more here: https://www.python.org/dev/peps/pep-0257/
I didn't know this. Good to know.
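A tiny sketch of the convention being cited:

```python
# Description style (discouraged): """Starts a background thread ..."""
# Imperative style (preferred, per PEP 257):
def start():
    """Start a background thread that deletes old checkpoints."""
```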
# (and on S3 if used in SageMaker). To save disk space and reduce the amount
# of data required to be downloaded on resumption, we retain only the N
# (default 5, specified by 'max_to_keep') most recent checkpoints.
file_to_delete = self.format_path(current_iteration - self.max_to_keep)
I'm personally not a fan of huge inline block comments; if a piece of code requires this level of commenting to make sense, it usually means the code is too confusing. I would consider refactoring to make it clearer, or even moving it to its own helper method for clarity.
# of data required to be downloaded on resumption, we retain only the N
# (default 5, specified by 'max_to_keep') most recent checkpoints.
file_to_delete = self.format_path(current_iteration - self.max_to_keep)
file_to_delete_exists = os.path.isfile(file_to_delete)
Instead of checking this here, I would just check during the actual delete. Just because the file exists now doesn't necessarily mean it will exist at deletion time, so this check isn't very protective.
Thanks for this. I think this is a good suggestion.
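One way to act on this suggestion is an EAFP-style delete (a sketch; the helper name is hypothetical):

```python
import os

def try_remove(path):
    # Attempt the delete and handle absence, instead of checking
    # os.path.isfile beforehand; this avoids the check-then-act race
    # the reviewer points out.
    try:
        os.remove(path)
        return True
    except FileNotFoundError:
        return False
```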
offset_iteration = env.end_iteration if self.num_round is None else self.num_round
training_has_ended = (current_iteration + 1 >= self.start_iteration + offset_iteration)
if training_has_ended:
    self.stop()
flake8 requires new line at end of file.
Is this true? When I include a new line, I get W391 blank line at end of file.
Have you run python3 -m tox?
with open(os.getenv(sm_env_constants.SM_INPUT_DATA_CONFIG_FILE), "r") as f:
    data_config = json.load(f)

checkpoint_config_file = os.getenv(sm_env_constants.SM_CHECKPOINT_CONFIG_FILE)
Very nitpicky, but I would just follow the pattern established above and skip the variable initialization for checkpoint_config_file.
train_args["xgb_model"] = xgb_model
train_args["callbacks"] = callbacks
# xgboost's default value for num_boost_round is 10.
num_boost_round = train_args.get("num_boost_round", 10) - start_iteration
As a thought exercise: what if the default for xgb changes to something other than 10?
Should we try to expose this constant upstream? We could do a PR that adds something like DEFAULT_NUM_BOOST_ROUND = 10 in https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/training.py and replaces all num_boost_round=10 with num_boost_round=DEFAULT_NUM_BOOST_ROUND.
# xgboost's default value for num_boost_round is 10.
num_boost_round = train_args.get("num_boost_round", 10) - start_iteration
# if last checkpoint is greater than num_boost_round, we shouldn't train.
train_args["num_boost_round"] = max(0, num_boost_round)
If num_boost_round is negative, shouldn't we just return the model? You call train on the next line, which seems unnecessary.
We still need to call xgb.train() because xgb_model is a filename (str) while we have to return a booster object. But that line seems unnecessary. Removed it after testing xgb.train() with negative numbers and reading the upstream source code. Passing zero or a negative number of rounds results in no training.
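The round arithmetic under discussion, sketched with the default pulled out as an explicit constant (the constant is an assumption mirroring xgboost's current default of 10, which is exactly the fragility the reviewer raises):

```python
DEFAULT_NUM_BOOST_ROUND = 10  # assumption: mirrors xgboost's current default

def remaining_rounds(train_args, start_iteration):
    # Resuming at or past the requested round count clamps to zero,
    # which makes the subsequent xgb.train() call perform no boosting.
    requested = train_args.get("num_boost_round", DEFAULT_NUM_BOOST_ROUND)
    return max(0, requested - start_iteration)
```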
return booster


# modified from https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/callback.py
I'm not a fan of inline comments. Can we condense this information (let's not include log lines in docs) and put it in the method docs?
Actually, I realized this is the method you lifted. Let's condense the comments and not include log lines if possible!
class TestSaveCheckpoint(unittest.TestCase):

    def setUp(self):
Super nitpick and personal choice, but you have a lot of unneeded new lines in this file :)
On retrying, the thought was that there's a chance some files are intact even if others are corrupt. We retain N=5 checkpoints (N retries is basically why we save N checkpoints), and I think it's much cheaper for users to save N checkpoints and try them all than to train from scratch if one fails.
Ran python3 -m tox and confirmed that all tests/flake8 pass.
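The try-them-all fallback described here can be sketched generically (loader stands in for xgboost's Booster loading, which is why it is a parameter; the function name is hypothetical):

```python
def load_newest_loadable(checkpoint_paths, loader, max_try=5):
    # Try up to max_try checkpoints, newest first; an older intact file
    # beats retraining from scratch if the newest one is corrupt.
    for path in sorted(checkpoint_paths, reverse=True)[:max_try]:
        try:
            return loader(path)
        except Exception:
            continue  # corrupt (e.g. a partial upload); fall back to older
    return None
```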
Finished one final round of testing. Unit tests passed:
Functional tests and integration tests passed. See CR-12948549. I think it's ready to be merged.
Issue #, if available:
Description of changes:
Summary
- callback.SaveCheckpoint: A callback for saving checkpoints.
- train_utils.load_checkpoint: For reloading models upon resuming training.
- Unit tests in test/unit/algorithm_mode/test_callback.py.
Tests
- pytest test/unit
Context
This PR is to support checkpointing for managed spot training, which was recently launched. Since spot instances can be interrupted at any time, we need to be able to save checkpoints during training and to resume from the last checkpoint when the training job is restarted.
To use managed spot training, customers specify EnableManagedSpotTraining and CheckpointConfig in CreateTrainingJob, or the corresponding parameters in the SDK:
When customers specify CheckpointConfig, the LocalPath becomes available in the config file at /opt/ml/input/config/checkpointconfig.json:
Any files saved at this location will automatically be uploaded to S3. When a training job is resumed from interruption, all uploaded files will be downloaded to the same location before training restarts.
Implementation
We first check whether the config file is present at /opt/ml/input/config/checkpointconfig.json. If there is no config file, we don't save checkpoints. If it is available, we instantiate callback.SaveCheckpoint and add it to the list of callback functions that are called at the end of each iteration. SaveCheckpoint saves each checkpoint to a different file at the end of each iteration by appending the iteration number to the file name, e.g., xgboost-checkpoint.1, xgboost-checkpoint.2, and so on. When files are written to the directory specified by LocalPath, SageMaker will automatically upload all files to the S3 location specified by S3Uri.
Since saving one checkpoint per iteration could result in a large number of files to save in S3 and download when spot instances are resumed, we retain only the 5 most recent checkpoints in the directory. This is accomplished by a background thread that deletes all checkpoints older than the 5 most recent (the number of files to keep is somewhat arbitrary; choosing the optimal number is left for future work). Note that when a file is being uploaded by SageMaker, SM will create a marker file (file name + .sagemaker-uploading) to indicate that the file is being uploaded. SM will also create another marker file (file name + .sagemaker-uploaded) when the upload is completed. Thus, the background thread will skip deleting a file and try again later if the marker file <filename>.sagemaker-uploading is present, and will only attempt to delete a file when the marker file <filename>.sagemaker-uploaded is present.
CloudWatch logs
Suppose the customer is training for 100 rounds, and the spot instance is interrupted after iteration 49. When the training job is resumed, we want to load xgboost-checkpoint.49 and restart training from iteration 50. We use the iteration numbers in the file names in our modified version of the print_evaluation() function to provide logs that are consistent with what customers would expect.
First phase: Starting ↠ Downloading ↠ Training ↠ Interrupted
Second phase: Resumed ↠ Starting ↠ Downloading ↠ Training
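The resume arithmetic in the example above (checkpoint 49 means training restarts at iteration 50) can be sketched as follows (the helper name is illustrative):

```python
import re

def resume_iteration(checkpoint_filename):
    # "xgboost-checkpoint.49" was saved at the end of iteration 49,
    # so training should resume at iteration 50.
    match = re.search(r"\.(\d+)$", checkpoint_filename)
    return int(match.group(1)) + 1
```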
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.