New-style checkpointing (again) #307
Conversation
Automatically detect the old format and handle accordingly.
Still in "draft" mode because I have yet to test this on LUMI. I will tomorrow when it's back up.
This format means we'd need a new unsharder, but we can write it with gloo and without hacks. We just spin up 256 ranks in separate processes on a single machine, have them load a checkpoint, and call the save_unsharded() function.
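For illustration, such an unsharder could look roughly like the sketch below; load_sharded_checkpoint and save_unsharded are illustrative stand-ins for whatever the trainer/checkpointer actually exposes, not real OLMo APIs.

import torch.distributed as dist
import torch.multiprocessing as mp

def unshard_worker(rank: int, world_size: int, checkpoint_dir: str, output_dir: str):
    # gloo runs on CPU, so a single machine with no GPUs is enough.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )
    state = load_sharded_checkpoint(checkpoint_dir)  # hypothetical: each rank loads its shard
    save_unsharded(state, output_dir)                # hypothetical: gather and write the full state
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 256
    mp.spawn(unshard_worker, args=(world_size, "path/to/sharded", "path/to/unsharded"), nprocs=world_size)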
olmo/checkpoint.py
Outdated
fut = super().write_data(plan, planner)
if self.upload_to is not None:
    files_to_upload = set()
    for write_result in fut.value():
Calling .value() here means the first future needs to be completed at this point. I assume it waits? Wouldn't it be better (and possibly necessary for correctness?) to wait inside the thread pool?
The future needs to get waited on, I think.
https://pytorch.org/docs/stable/futures.html#torch.futures.Future.value
The value of this particular future is available immediately. See https://github.com/pytorch/pytorch/blob/7827ae2864afa1955bc9ce04d168b274700d24e5/torch/distributed/checkpoint/filesystem.py#L429.
Why do they have this system with the futures if they are not using it?
It's the API for this class. Maybe they had another use case in mind when they designed it, I don't know. I added a .wait() just in case FileSystemWriter changes in a later release. 7c7f6dc
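For reference, the defensive pattern described here looks roughly like this (paraphrasing the write_data override; the loop body is elided):

fut = super().write_data(plan, planner)
fut.wait()  # a no-op today, since FileSystemWriter completes the future synchronously,
            # but safe if that behavior changes in a later release
if self.upload_to is not None:
    files_to_upload = set()
    for write_result in fut.value():
        ...  # collect the files each write landed in, then upload them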
for f in as_completed(futures):
    f.result()
Wait, this thing is returning futures, but not the futures that are doing the uploading? Is that right? Why?
Right. The future it returns is a PyTorch Future, not a Future from the Python std lib.
😵💫
No way to convert one to the other? No benefit of doing so?
Not that I know of
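If converting ever did become useful, one hedged way to bridge them (not something this PR does) would be to complete a torch future from a done-callback on the Python one:

from concurrent.futures import Future as PyFuture
import torch.futures

def to_torch_future(py_fut: PyFuture) -> torch.futures.Future:
    # Adapt a Python concurrent.futures.Future into a torch.futures.Future.
    torch_fut: torch.futures.Future = torch.futures.Future()

    def _transfer(f: PyFuture) -> None:
        try:
            torch_fut.set_result(f.result())
        except Exception as e:  # propagate failures as well
            torch_fut.set_exception(e)

    py_fut.add_done_callback(_transfer)
    return torch_fut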
olmo/train.py
Outdated
# Load the model state dict in place.
log.info("Loading model state...")
model_state = {"model": self.fsdp_model.state_dict()}
load_state_dict(model_state, RemoteFileSystemReader(f"{load_path}/model_and_optim"))
Will we know that this is a new-style checkpoint by the presence of the model_and_optim directory?
Yes
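For a local path, that detection amounts to something like this sketch (the actual check in the PR may differ, and remote paths would need their own existence check):

from pathlib import Path

def is_new_style_checkpoint(load_path: str) -> bool:
    # New-style sharded checkpoints keep model/optimizer shards under model_and_optim/.
    return (Path(load_path) / "model_and_optim").is_dir()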
# Restoring RNG state isn't necessary and in the case of going from world size 1 to world size N
# we probably don't want every rank to have the exact same RNG state.
del trainer_state["rng"]
From an RNG point of view, this means we get a different model when switching world sizes mid-training, I assume?
Yes... but there's no way to avoid that.
As long as the files with the indices reflect what happened, it's all good.
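For context, once the saved RNG state is dropped, each rank can simply re-derive its own stream from the configured seed, e.g. something along these lines (illustrative only, not the exact code in the PR):

import torch

def reseed_for_rank(base_seed: int, global_rank: int) -> None:
    # Give every rank a distinct but deterministic RNG stream after restoring.
    seed = base_seed + global_rank
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)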
olmo/train.py
Outdated
barrier()

def restore_legacy_sharded_checkpoint(self, load_path: PathOrStr):
It seems like this is not too much code, but I'd be fine saying we only restore legacy unsharded checkpoints. At least long term. Tomorrow I want to start running immediately on one of those, which hasn't been unsharded yet. So maybe it's good this is here.
Curious to see how it goes on LUMI
try:
    resource_path(load_path, f"rank{get_global_rank()}.pt")
    legacy_mode = True
except FileNotFoundError:
This seems to assume that FileNotFoundError will be raised if the file passed to resource_path does not exist. However, the else block of resource_path looks like it can be satisfied by a local file that does not exist.
Maybe a bit irrelevant, but why not just remove the else block of resource_path? It looks like cached_path will check existence of local files.
This seems to assume that FileNotFoundError will be raised if the file passed to resource_path does not exist. However, the else block of resource_path looks like it can be satisfied by a local file that does not exist.
Good catch. f357b5e
Maybe a bit irrelevant, but why not just remove the else block of resource_path? It looks like cached_path will check existence of local files.
Technically that would break for local files on Windows, where the path separator is not "/". Not like we'll be training on Windows anyway, but we might as well avoid potential bugs when possible.
Technically that would break for local files on Windows, where the path separator is not "/".
I would think that the Python path abstraction would be able to deal with the alternate path separator. The doc does seem to allow forward slashes for Windows paths: https://docs.python.org/3/library/pathlib.html#pathlib.PureWindowsPath.
@2015aroras you're right. deeb8fb
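After that change, the simplified shape of resource_path would be roughly as follows (a sketch, not the exact diff in deeb8fb):

from pathlib import Path
from cached_path import cached_path

def resource_path(folder: str, fname: str) -> Path:
    # cached_path handles both remote URLs and local paths, and raises
    # FileNotFoundError for local files that don't exist.
    return cached_path(f"{folder}/{fname}")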
olmo/train.py
Outdated
try:
    train_state_dict = torch.load(resource_path(load_path, "other.pt"))  # for backwards compatibility
except FileNotFoundError:
    train_state_dict = torch.load(resource_path(load_path, "train.pt"))
nit: Maybe let's try train.pt first, and fall back to the legacy name if needed.
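i.e. something like the following (sketch):

try:
    train_state_dict = torch.load(resource_path(load_path, "train.pt"))
except FileNotFoundError:
    # fall back to the legacy name
    train_state_dict = torch.load(resource_path(load_path, "other.pt"))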
For the medium model on LUMI, this decreases the total sharded checkpoint size from 200G to 28G! 😮
^ @2015aroras, this is why we have to unshard all the checkpoints. Even the ones we already made with the old method.
Saving / loading from LUMI to S3 is pretty quick too for the medium model (1-2 mins).
As of bb16bd7, unsharding a new-style sharded checkpoint is as simple as this:

import torch
from olmo import Olmo, CheckpointType

model = Olmo.from_checkpoint(
    "path/to/sharded/checkpoint",
    device="cpu",  # "cuda" works fine too, and might be faster
    checkpoint_type=CheckpointType.sharded,
)
torch.save(model.state_dict(), "path/to/unsharded/checkpoint/model.pt")
Switches to PyTorch's new recommended checkpointing functionality, torch.distributed.checkpoint. The benefit of using the new checkpointing module is that we can save and load (sharded) checkpoints across different world sizes M and N, even if M or N is 1, and the total size of the checkpoints should be much smaller than the artifacts from our current sharded checkpointing method.

In order to make this work smoothly on MosaicML or other platforms where there isn't a shared file system between nodes, I had to implement a custom StorageWriter and StorageReader. These classes, RemoteFileSystemWriter and RemoteFileSystemReader respectively, work just like the standard PyTorch FileSystemWriter and FileSystemReader when writing/reading checkpoints to/from a local directory, but they are also capable of writing/reading to/from cloud storage, which is necessary when nodes don't have access to a shared file system.

These changes are backwards compatible in that we can still load our "old-style" sharded checkpoints, so we can resume an existing run after this merges.
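As a rough illustration of the idea behind RemoteFileSystemWriter (the real implementation is in olmo/checkpoint.py and differs in detail; upload_file and collect_written_files below are illustrative placeholders, not real helpers):

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

from torch.distributed.checkpoint import FileSystemWriter

class RemoteFileSystemWriter(FileSystemWriter):
    def __init__(self, path: str, upload_to: Optional[str] = None, **kwargs):
        super().__init__(path, **kwargs)
        self.upload_to = upload_to  # e.g. an s3:// prefix; None means local-only

    def write_data(self, plan, planner):
        fut = super().write_data(plan, planner)
        fut.wait()
        if self.upload_to is not None:
            # Upload every file this rank wrote, using a thread pool, and wait for all uploads.
            with ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(upload_file, local_path, self.upload_to)  # placeholder helper
                    for local_path in collect_written_files(fut.value())      # placeholder helper
                ]
                for f in as_completed(futures):
                    f.result()  # re-raise any upload errors
        return fut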