Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions code/workflows/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,32 +108,44 @@ def find_best_model(path, *, rank: int = 0):
max_value = -1 # Start with -1 to include epoch 0
best_file = None
model_files = []
# Named .pt files without epoch numbers (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt)
named_pt_files = []

for filename in os.listdir(path):
if not filename.startswith("PreDecoderModelMemory_"):
continue
try:
value = float(filename.split(".")[2]) # Gets epoch number
model_files.append((filename, value))
if value > max_value:
max_value = value
best_file = filename
except (IndexError, ValueError) as e:
print(f"Warning: could not parse epoch from filename {filename}: {e}")
if not filename.endswith(".pt"):
continue
if filename.startswith("PreDecoderModelMemory_"):
try:
value = float(filename.split(".")[2]) # Gets epoch number
model_files.append((filename, value))
if value > max_value:
max_value = value
best_file = filename
except (IndexError, ValueError) as e:
print(f"Warning: could not parse epoch from filename {filename}: {e}")
else:
named_pt_files.append(filename)

# Fall back to named .pt files when no epoch-numbered checkpoints are present
if best_file is None and named_pt_files:
named_pt_files.sort()
best_file = named_pt_files[-1]
model_files = [(f, None) for f in named_pt_files]

if rank == 0:
print(f"Found {len(model_files)} model files:")
for filename, epoch in sorted(model_files, key=lambda x: x[1]):
print(f"Found {len(model_files)} model file(s):")
for filename, epoch in sorted(model_files, key=lambda x: (x[1] is None, x[1] or 0)):
marker = "*" if filename == best_file else " "
print(f" [{marker}] {filename} (epoch {epoch})")
epoch_str = str(epoch) if epoch is not None else "n/a"
print(f" [{marker}] {filename} (epoch {epoch_str})")

if best_file is None:
raise FileNotFoundError(f"No valid model checkpoint files found in {path}")

best_model_path = os.path.join(path, best_file)
if rank == 0:
print(f"Selected best model: {best_file} (epoch {max_value})")
epoch_str = str(max_value) if max_value >= 0 else "n/a"
print(f"Selected best model: {best_file} (epoch {epoch_str})")

return best_model_path

Expand Down
Loading