In [1]:
import os
import wandb
import torch
from pathlib import Path

def _newest_file_with_suffix(files, suffix):
    """Return the most recently updated file whose name ends with ``suffix``, or ``None``.

    ``files`` is an iterable of W&B file objects exposing ``name`` and
    ``updated_at``. Ordering uses ``updated_at`` directly and assumes those
    values sort chronologically. Ties keep the first match in listing order,
    matching a stable newest-first sort.
    """
    matches = [candidate for candidate in files if candidate.name.endswith(suffix)]
    return max(matches, key=lambda candidate: candidate.updated_at, default=None)


def download_and_modify_weights(run_id, output_folder):
    """Repackage a run's newest ``*best_success_rate.pt`` checkpoint with its config.

    The checkpoint is downloaded from W&B and loaded with ``torch.load``.
    ``run.config`` is stored under the ``'config'`` key, and the result is
    saved as ``<output_folder>/actor_chkpt.pt``.

    Args:
        run_id: W&B run path accepted by ``wandb.Api().run``
            (e.g. ``"ol-state-dr-1/e3d4a367"``).
        output_folder: Destination directory (str or Path). It is created,
            including parents, if missing.

    Returns:
        None. Failures (run not found, no matching checkpoint) are printed
        rather than raised, so a batch of runs keeps going.

    Side effects:
        The checkpoint is downloaded into the current working directory under
        its W&B file name, replacing any existing local copy.
    """
    api = wandb.Api()

    try:
        run = api.run(run_id)
    except wandb.errors.CommError:
        print(f"Error: Unable to find run with id {run_id}")
        return

    # Select the newest matching checkpoint. Previously the list was sorted
    # but the search iterated an unsorted re-listing, so the sort had no effect.
    best_file = _newest_file_with_suffix(run.files(), 'best_success_rate.pt')

    if best_file is None:
        print(f"Error: No file ending with 'best_success_rate.pt' found for run {run_id}")
        return

    # File.download hands back an open handle to the local copy in recent
    # wandb versions. Close it, since the file is re-read by torch.load below.
    handle = best_file.download(replace=True)
    if handle is not None:
        handle.close()

    # Assumes the checkpoint deserializes to a dict-like object (it is
    # indexed by key below).
    weights = torch.load(best_file.name)
    weights['config'] = run.config

    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Every run is saved under the same simplified filename inside its folder.
    output_path = output_folder / "actor_chkpt.pt"
    torch.save(weights, str(output_path))

    print(f"Modified weights for run {run_id} saved to {output_path}")

def process_runs(run_dict):
    """Repackage the best checkpoint of every run listed in ``run_dict``.

    Args:
        run_dict: Mapping of output sub-folder -> W&B run path. Each sub-folder
            is resolved under the ``CHECKPOINT_PATH`` environment variable.
            When that variable is unset, the sub-folder is relative to the
            current working directory.

    Runs are handled in the mapping's iteration order, with one progress line
    printed before each download.
    """
    base_dir = Path(os.getenv('CHECKPOINT_PATH', ''))

    for folder_name, wandb_run_path in run_dict.items():
        destination = base_dir / folder_name
        print(f"Processing run {wandb_run_path} to output folder {destination}")
        download_and_modify_weights(wandb_run_path, destination)

In [2]:
# Root directory for the repackaged checkpoints. process_runs reads it back
# from the CHECKPOINT_PATH environment variable.
CHECKPOINT_ROOT = "/data/scratch/ankile/rr-best-checkpoints"
os.environ["CHECKPOINT_PATH"] = CHECKPOINT_ROOT

In [None]:
# Output sub-folder (under CHECKPOINT_PATH) -> W&B run path whose best
# checkpoint should be repackaged there. Pairs are listed in processing order.
_RUN_PAIRS = [
    ("bc/one_leg/low", "ol-state-dr-1/e3d4a367"),
    ("bc/one_leg/med", "ol-state-dr-med-1/9zjnzg4r"),
    ("bc/round_table/low", "rt-state-dr-low-1/z3efusm6"),
    ("bc/round_table/med", "rt-state-dr-med-1/n5g6x9jg"),
    ("bc/lamp/low", "lp-state-dr-low-1/b5dcl1tt"),
    ("bc/lamp/med", "lp-state-dr-med-1/fziwvs8k"),
    ("bc/mug_rack/low", "mr-state-dr-low-1/uet1h1ex"),
    ("bc/factory_peg_hole/low", "fph-state-dr-low-1/4vwizwue"),
    ("rppo/one_leg/low", "ol-rppo-dr-low-1/jamz5ley"),
    ("rppo/one_leg/med", "ol-rppo-dr-med-1/oipdyimz"),
    ("rppo/round_table/low", "rt-rppo-dr-low-1/np48i8wp"),
    ("rppo/round_table/med", "rt-rppo-dr-med-1/k737s8lj"),
    ("rppo/lamp/low", "lp-rppo-dr-low-1/hd2i5gje"),
    ("rppo/lamp/med", "lp-rppo-dr-med-1/ev23t35c"),
    ("rppo/mug_rack/low", "mr-rppo-dr-low-1/dvw6zk8e"),
    ("rppo/factory_peg_hole/low", "fph-rppo-dr-low-1/2kd9vgx9"),
]
runs_to_process = dict(_RUN_PAIRS)

# Download, annotate with run config, and save every checkpoint.
process_runs(runs_to_process)