Skip to content

Commit

Permalink
make deterministic, prioritize sharded
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Oct 31, 2023
1 parent ae6fadd commit 3a38d9a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions olmo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,8 @@ def find_latest_checkpoint(dir: PathOrStr) -> Optional[PathOrStr]:
else:
latest_step = 0
latest_checkpoint: Optional[Path] = None
for path in Path(dir).glob("step*"):
# Sorting here guarantees that we prioritize sharded checkpoints over unsharded checkpoints.
for path in sorted(Path(dir).glob("step*")):
if path.is_dir():
try:
step = int(path.name.replace("step", "").replace("-unsharded", ""))
Expand Down Expand Up @@ -649,7 +650,8 @@ def _s3_find_latest_checkpoint(bucket_name: str, prefix: str) -> Optional[str]:
assert not response["IsTruncated"] # need to handle this if it happens
latest_step = 0
latest_checkpoint: Optional[str] = None
for item in response["CommonPrefixes"]:
# Sorting here guarantees that we prioritize sharded checkpoints over unsharded checkpoints.
for item in sorted(response["CommonPrefixes"], key=lambda x: x["Prefix"]):
prefix = item["Prefix"].strip("/")
checkpoint_name = os.path.split(prefix)[-1]
if not checkpoint_name.startswith("step"):
Expand Down

0 comments on commit 3a38d9a

Please sign in to comment.