
New-style checkpointing (again) #307

Merged: epwalsh merged 15 commits from petew/checkpointing into main on Oct 3, 2023
Conversation

@epwalsh (Member) commented Oct 2, 2023

Switches to PyTorch's new recommended checkpointing functionality: torch.distributed.checkpoint.

The benefit of using the new checkpointing module is that we can save a (sharded) checkpoint from world size M and load it at a different world size N, even if M or N is 1, and the total size of the checkpoints should be much smaller than the artifacts from our current sharded checkpointing method.

In order to make this work smoothly on MosaicML or other platforms where there isn't a shared file system between nodes, I had to implement a custom StorageWriter and StorageReader. These classes, RemoteFileSystemWriter and RemoteFileSystemReader respectively, work just like the standard PyTorch FileSystemWriter and FileSystemReader when writing/reading checkpoints to/from a local directory, but they can also write/read to/from cloud storage, which is necessary when nodes don't have access to a shared file system.
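As a rough illustration (not the exact OLMo implementation), such a writer can be built by subclassing PyTorch's FileSystemWriter and uploading the files it has just written; the upload_to argument and the upload() helper below are illustrative assumptions:

from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional

from torch.distributed.checkpoint import FileSystemWriter


def upload(local_path: Path, remote_url: str) -> None:
    """Placeholder for an S3/GCS/etc. upload helper."""
    ...


class RemoteFileSystemWriter(FileSystemWriter):
    """Sketch: write shard files locally, then mirror them to cloud storage."""

    def __init__(self, path, upload_to: Optional[str] = None, **kwargs):
        super().__init__(path, **kwargs)
        self.upload_to = upload_to.rstrip("/") if upload_to else None

    def write_data(self, plan, planner):
        # Let the base class write the local shard files first.
        fut = super().write_data(plan, planner)
        fut.wait()
        if self.upload_to is not None:
            # Assumption: each WriteResult's storage_data records the relative path
            # of the local file that was written.
            files_to_upload = {wr.storage_data.relative_path for wr in fut.value()}
            with ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(upload, Path(self.path) / name, f"{self.upload_to}/{name}")
                    for name in files_to_upload
                ]
                for f in as_completed(futures):
                    f.result()  # re-raise any upload errors
        return fut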

These changes are backwards compatible in that we can still load our "old-style" sharded checkpoints so we can resume an existing run after this merges.

@epwalsh (Member Author) commented Oct 2, 2023

Still in "draft" mode because I have yet to test this on LUMI. I will tomorrow when it's back up.

@dirkgr (Member) left a comment:

This format means we'd need a new unsharder, but we can write it in gloo and without hacks. We just spin up 256 ranks in separate processes on a single machine, and have them load a checkpoint, and call the save_unsharded() function.
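A bare-bones skeleton of that approach, with the actual shard-loading and unsharded-saving steps left as placeholders (they are not real APIs here):

import os

import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 256  # the world size the sharded checkpoint was written with


def unshard_worker(rank: int, checkpoint_dir: str, output_dir: str) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=WORLD_SIZE)
    # Placeholder: build the model/trainer, load this rank's shard from
    # checkpoint_dir, then call something like save_unsharded(output_dir).
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(unshard_worker, args=("path/to/sharded", "path/to/unsharded"), nprocs=WORLD_SIZE)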

fut = super().write_data(plan, planner)
if self.upload_to is not None:
    files_to_upload = set()
    for write_result in fut.value():
@dirkgr (Member):

Calling .value() here means the first future needs to be completed at this point. I assume it waits? Wouldn't it be better (and possibly necessary for correctness?) to wait inside the thread pool?


@dirkgr (Member):

Why do they have this system with the futures if they are not using it?

@epwalsh (Member Author):

It's the API for this class. Maybe they had another use case in mind when they designed it, I don't know. I added a .wait() just in case FileSystemWriter changes in a later release. 7c7f6dc
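Presumably that change amounts to something like the following (a sketch of the pattern, not the actual commit):

fut = super().write_data(plan, planner)
fut.wait()  # explicitly block until the future completes before reading its value
if self.upload_to is not None:
    for write_result in fut.value():
        ...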

Comment on lines +72 to +73
for f in as_completed(futures):
    f.result()
@dirkgr (Member):

Wait, this thing is returning futures, but not the futures that are doing the uploading? Is that right? Why?

@epwalsh (Member Author):

Right. The future it returns is a PyTorch Future, not a Future from the Python std lib.

@dirkgr (Member):

😵‍💫
No way to convert one to the other? No benefit of doing so?

@epwalsh (Member Author):

Not that I know of
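For context, the two Future types are separate classes with different APIs, which is why the std-lib as_completed/result machinery can't be applied directly to what write_data returns. A small illustration (not from the PR):

from concurrent.futures import ThreadPoolExecutor

import torch.futures

torch_fut = torch.futures.Future()  # the kind returned by write_data
torch_fut.set_result(42)
print(torch_fut.wait(), torch_fut.value())  # PyTorch Futures use wait()/value()/then()

with ThreadPoolExecutor() as pool:
    py_fut = pool.submit(lambda: 42)  # the kind returned by an executor
    print(py_fut.result())  # std-lib Futures use result()/add_done_callback()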

olmo/train.py Outdated
# Load the model state dict in place.
log.info("Loading model state...")
model_state = {"model": self.fsdp_model.state_dict()}
load_state_dict(model_state, RemoteFileSystemReader(f"{load_path}/model_and_optim"))
@dirkgr (Member):

By the presence of the model_and_optim directory will we know that this is a new-style checkpoint?

@epwalsh (Member Author):

Yes

Comment on lines +380 to +382
# Restoring RNG state isn't necessary and in the case of going from world size 1 to world size N
# we probably don't want every rank to have the exact same RNG state.
del trainer_state["rng"]
@dirkgr (Member):

From a RNG point of view, this means we get a different model when switching world sizes mid-training, I assume?

@epwalsh (Member Author):

Yes, but there's no way to avoid that.

@dirkgr (Member):

As long as the files with the indices reflect what happened, it's all good.

olmo/train.py Outdated

barrier()

def restore_legacy_sharded_checkpoint(self, load_path: PathOrStr):
@dirkgr (Member):

It seems like this is not too much code, but I'd be fine saying we only restore legacy unsharded checkpoints. At least long term. Tomorrow I want to start running immediately on one of those, which hasn't been unsharded yet. So maybe it's good this is here.

@2015aroras (Collaborator) left a comment:

Curious to see how it goes on LUMI

try:
    resource_path(load_path, f"rank{get_global_rank()}.pt")
    legacy_mode = True
except FileNotFoundError:
@2015aroras (Collaborator):

This seems to implicitly assume that FileNotFoundError will be raised if the file passed to resource_path does not exist. However, the else block of resource_path looks like it can be satisfied by a local file that does not exist.

@2015aroras (Collaborator):

Maybe a bit irrelevant, but why not just remove the else block of resource_path? It looks like cached_path will check existence of local files.

@epwalsh (Member Author):

> This seems to implicitly assume that FileNotFoundError will be raised if the file passed to resource_path does not exist. However, the else block of resource_path looks like it can be satisfied by a local file that does not exist.

Good catch. f357b5e

> Maybe a bit irrelevant, but why not just remove the else block of resource_path? It looks like cached_path will check existence of local files.

Technically that would break for local files on Windows, where the path separator is not "/". Not like we'll be training on Windows anyway, but might as well avoid potential bugs when possible.

@2015aroras (Collaborator) commented Oct 3, 2023:

> Technically that would break for local files on Windows, where the path separator is not "/".

I would think that the Python path abstraction would be able to deal with the alternate path separator. The docs do seem to allow forward slashes for Windows paths: https://docs.python.org/3/library/pathlib.html#pathlib.PureWindowsPath.

@epwalsh (Member Author):

@2015aroras you're right. deeb8fb
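For reference, if the else branch were dropped entirely, resource_path could reduce to a thin wrapper around cached_path, which raises FileNotFoundError for missing local files. This is a hedged sketch; the real helper's signature and behavior may differ:

from pathlib import Path

from cached_path import cached_path


def resource_path(folder: str, fname: str) -> Path:
    # cached_path handles local paths, URLs, and cloud storage, and raises
    # FileNotFoundError when a local file does not exist.
    return Path(cached_path(f"{folder}/{fname}"))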

olmo/train.py Outdated
try:
    train_state_dict = torch.load(resource_path(load_path, "other.pt"))  # for backwards compatibility
except FileNotFoundError:
    train_state_dict = torch.load(resource_path(load_path, "train.pt"))
@2015aroras (Collaborator):

nit: Maybe let's try train.pt first, and fall back to the legacy other.pt if needed.
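That is, roughly this reordering of the hunk above (a sketch of the suggestion, using the same torch.load/resource_path calls):

try:
    train_state_dict = torch.load(resource_path(load_path, "train.pt"))
except FileNotFoundError:
    # fall back to the legacy name used by older checkpoints
    train_state_dict = torch.load(resource_path(load_path, "other.pt"))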



@epwalsh (Member Author) commented Oct 3, 2023

For the medium model on LUMI, this decreases the total sharded checkpoint size from 200G to 28G! 😮

@dirkgr (Member) commented Oct 3, 2023

^ @2015aroras, this is why we have to unshard all the checkpoints. Even the ones we already made with the old method.

@epwalsh (Member Author) commented Oct 3, 2023

Saving / loading from LUMI to S3 is pretty quick too for the medium model (1-2 mins).

@epwalsh (Member Author) commented Oct 3, 2023

As of bb16bd7, unsharding a new-style sharded checkpoint is as simple as this:

import torch

from olmo import Olmo, CheckpointType

model = Olmo.from_checkpoint(
    "path/to/sharded/checkpoint",
    device="cpu",  # "cuda" works fine too, and might be faster
    checkpoint_type=CheckpointType.sharded,
)
torch.save(model.state_dict(), "path/to/unsharded/checkpoint/model.pt")

epwalsh marked this pull request as ready for review on October 3, 2023, 22:47
epwalsh merged commit 602968a into main on Oct 3, 2023
10 checks passed
epwalsh deleted the petew/checkpointing branch on October 3, 2023, 22:53