In [10]:
import os
import tempfile
from pathlib import Path

import torch
import wandb

def download_and_modify_weights(
    run_id,
    output_folder,
    checkpoint_suffix="best_success_rate.pt",
    output_filename="actor_chkpt.pt",
):
    """Fetch a run's checkpoint from W&B, attach the run config, and re-save it.

    The first file returned by ``run.files()`` whose name ends with
    ``checkpoint_suffix`` is downloaded into a temporary directory, loaded with
    ``torch.load`` (tensors mapped to CPU), given a ``'config'`` entry holding
    ``run.config``, and written to ``output_folder / output_filename``.

    Args:
        run_id: Run path passed to ``wandb.Api().run``, e.g.
            ``"ol-state-dr-1/e3d4a367"``. ``None`` or an empty string is
            reported and skipped.
        output_folder: Directory for the re-saved checkpoint; created (with
            parents) if it does not exist.
        checkpoint_suffix: Filename suffix that selects the checkpoint to fetch.
        output_filename: Name of the file written inside ``output_folder``.

    Returns:
        The ``Path`` of the written checkpoint, or ``None`` if ``run_id`` is
        empty, the run cannot be found, or no matching file exists (an error
        message is printed in each of those cases).

    Note:
        Assumes the loaded checkpoint is a dict-like object that accepts a
        ``'config'`` key — TODO confirm for every exporter that writes these.
    """
    if not run_id:
        print(f"Error: No run id given for output folder {output_folder}")
        return None

    api = wandb.Api()

    # A missing run can surface either as a CommError (HTTP/GraphQL failure)
    # or as a ValueError ("Could not find run ...") raised while loading it.
    try:
        run = api.run(run_id)
    except (wandb.errors.CommError, ValueError):
        print(f"Error: Unable to find run with id {run_id}")
        return None

    best_file = next(
        (file for file in run.files() if file.name.endswith(checkpoint_suffix)),
        None,
    )
    if best_file is None:
        print(f"Error: No file ending with '{checkpoint_suffix}' found for run {run_id}")
        return None

    # Download into a throw-away directory rather than the current working
    # directory, so no stray checkpoint copies are left next to the notebook.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # File.download hands back an open handle to the downloaded file; we
        # only need the path, so close it immediately (avoids a leaked handle,
        # which would also block temp-dir cleanup on Windows).
        handle = best_file.download(root=tmp_dir, replace=True)
        if handle is not None:
            handle.close()
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
        weights = torch.load(Path(tmp_dir) / best_file.name, map_location="cpu")

    # Embed the training configuration so the exported file is self-describing.
    weights["config"] = run.config

    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)
    output_path = output_folder / output_filename

    torch.save(weights, output_path)

    print(f"Modified weights for run {run_id} saved to {output_path}")
    return output_path

def process_runs(run_dict, checkpoint_path=None):
    """Export the checkpoint of every run listed in ``run_dict``.

    Args:
        run_dict: Mapping of output sub-folder (relative to the checkpoint
            root) to the W&B run path handed to
            ``download_and_modify_weights``. Entries whose run path is ``None``
            or empty are skipped with a message instead of being sent on.
        checkpoint_path: Root directory for all outputs. Defaults to the
            ``CHECKPOINT_PATH`` environment variable; if that is unset too, the
            current working directory is used and a warning is printed.

    Returns:
        Dict mapping each processed output sub-folder to whatever
        ``download_and_modify_weights`` returned for it.
    """
    if checkpoint_path is None:
        env_root = os.environ.get("CHECKPOINT_PATH", "")
        if not env_root:
            # Path("") resolves to ".", so outputs would silently land in CWD.
            print("Warning: CHECKPOINT_PATH is not set; writing relative to the current directory")
        checkpoint_path = env_root
    checkpoint_path = Path(checkpoint_path)

    results = {}
    for output_folder, run_id in run_dict.items():
        full_output_path = checkpoint_path / output_folder
        if not run_id:
            # Placeholder entry (no run chosen yet) — nothing to download.
            print(f"Skipping {full_output_path}: no run id given")
            continue
        print(f"Processing run {run_id} to output folder {full_output_path}")
        results[output_folder] = download_and_modify_weights(run_id, full_output_path)
    return results

In [11]:
# Root directory that process_runs reads (via CHECKPOINT_PATH) to place exports.
os.environ.update(CHECKPOINT_PATH="/data/scratch/ankile/rr-best-checkpoints")

In [12]:
# Map each output sub-folder (relative to CHECKPOINT_PATH) to the W&B run path
# whose checkpoint should be exported there. A value of None marks a slot whose
# run has not been chosen yet; those entries are filtered out below so they
# are never sent to the W&B API.
runs_to_process = {
    "bc/one_leg/low": "ol-state-dr-1/e3d4a367",
    "bc/one_leg/med": "ol-state-dr-med-1/9zjnzg4r",
    "bc/round_table/low": None,
    "bc/round_table/med": None,
    "bc/lamp/low": None,
    "bc/lamp/med": None,
    "bc/mug_rack/low": None,
    "bc/factory_peg_hole/low": None,
    "rppo/one_leg/low": None,
    "rppo/one_leg/med": None,
    "rppo/round_table/low": None,
    "rppo/round_table/med": None,
    "rppo/lamp/low": None,
    "rppo/lamp/med": None,
    "rppo/mug_rack/low": None,
    "rppo/factory_peg_hole/low": None,
}

selected_runs = {
    folder: run_id for folder, run_id in runs_to_process.items() if run_id
}

process_runs(selected_runs)

Processing run ol-state-dr-1/e3d4a367 to output folder /data/scratch/ankile/rr-best-checkpoints/bc/one_leg/low
Modified weights for run ol-state-dr-1/e3d4a367 saved to /data/scratch/ankile/rr-best-checkpoints/bc/one_leg/low/actor_chkpt.pt
Processing run ol-state-dr-med-1/9zjnzg4r to output folder /data/scratch/ankile/rr-best-checkpoints/bc/one_leg/med
Modified weights for run ol-state-dr-med-1/9zjnzg4r saved to /data/scratch/ankile/rr-best-checkpoints/bc/one_leg/med/actor_chkpt.pt
