Skip to content

Commit

Permalink
update runner
Browse files Browse the repository at this point in the history
  • Loading branch information
KuoHaoZeng committed Jul 9, 2024
1 parent 04c5b80 commit 515ddd0
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions allenact/algorithms/onpolicy_sync/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,28 +1502,22 @@ def get_checkpoint_files(
):
if "wandb://" == checkpoint_path_dir_or_pattern[:8]:
import wandb
run_token = checkpoint_path_dir_or_pattern.split("//")[1]
import shutil
eval_dir = "wandb_ckpts_to_eval/{}".format(self.local_start_time_str)
os.makedirs(eval_dir, exist_ok=True)
api = wandb.Api()
run = api.run(run_token)
all_checkpoints = run.files()
run_token = checkpoint_path_dir_or_pattern.split("//")[1]
ckpt_steps = checkpoint_path_dir_or_pattern.split("//")[2:]
if ckpt_steps[-1] == "":
ckpt_steps = ckpt_steps[:-1]
ckpts_paths = []
for steps in ckpt_steps:
for ckpts in all_checkpoints:
if steps in ckpts.name:
ckpts.download()
ckpts_paths.append(ckpts.name)
try:
self.checkpoint_start_time_str(ckpts_paths[0])
except:
import shutil
eval_dir = "wandb_ckpts_to_eval/{}".format(self.local_start_time_str)
os.makedirs(eval_dir, exist_ok=True)
for ckpt in ckpts_paths:
shutil.move(ckpt, os.path.join(eval_dir, ckpt))
ckpts_paths = glob.glob(os.path.join(eval_dir, "*.pt"))
ckpt_fn = "{}-step-{}:latest".format(run_token, steps)
artifact = api.artifact(ckpt_fn)
_ = artifact.download("tmp")
ckpt_dir = "{}/ckpt-{}.pt".format(eval_dir, steps)
shutil.move("tmp/ckpt.pt", ckpt_dir)
ckpts_paths.append(ckpt_dir)
return ckpts_paths

if os.path.isdir(checkpoint_path_dir_or_pattern):
Expand Down

0 comments on commit 515ddd0

Please sign in to comment.