In [1]:
"""
Usage:
python eval.py --checkpoint data/image/pusht/diffusion_policy_cnn/train_0/checkpoints/latest.ckpt -o data/pusht_eval_output
"""

import sys
# use line-buffering for both stdout and stderr
# sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
# sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1)

import os
import pathlib
import hydra
import torch
import dill
import wandb
import json
import time
from diffusion_policy.workspace.base_workspace import BaseWorkspace

In [2]:
# Evaluation settings for this notebook run; all three are passed to main().
# Path to the serialized training checkpoint (opened in binary mode by main).
checkpoint = 'models/pre-trained/epoch=0100-test_mean_score=0.748.ckpt'
# Directory handed to the workspace and the env runner as their output_dir.
output_dir = 'data/pusht_eval_output'
# Device string converted with torch.device(); the policy is moved onto it.
device = 'cuda:0'

# eval
def main(checkpoint, output_dir, device, save_log=False):
    """Load a checkpointed workspace and run its policy through the task's env runner.

    The checkpoint payload carries its own Hydra config (``payload['cfg']``),
    which is used to resolve the workspace class, to pick between the plain
    and the EMA model, and to instantiate the environment runner. The
    wall-clock duration of ``env_runner.run(policy)`` is printed in seconds.

    Args:
        checkpoint: Path to the checkpoint file; opened in binary mode and
            deserialized with ``torch.load`` using ``dill``.
        output_dir: Passed as ``output_dir`` to both the workspace constructor
            and the env runner. Also the location of ``eval_log.json`` when
            ``save_log`` is true.
        device: Anything accepted by ``torch.device`` (e.g. ``'cuda:0'``).
        save_log: When true, write the runner log to
            ``<output_dir>/eval_log.json``. Defaults to ``False``, which
            matches the previous behaviour where the function returned
            right after timing the rollout.

    Returns:
        None.
    """
    print("================================ Setup ================================")

    # load checkpoint
    # NOTE(security): torch.load unpickles the file (here via dill), which can
    # execute arbitrary code. Only load checkpoints from a trusted source.
    with open(checkpoint, 'rb') as ckpt_file:
        payload = torch.load(ckpt_file, pickle_module=dill)
    cfg = payload['cfg']
    cls = hydra.utils.get_class(cfg._target_)
    print("Workspace: {}".format(cfg._target_.split('.')[-1]))
    print("policy: {}".format(cfg.policy._target_.split('.')[-1]))
    print("environment: {}".format(cfg.task.env_runner._target_.split('.')[-1]))

    # initialize workspace
    workspace = cls(cfg, output_dir=output_dir)
    workspace: BaseWorkspace
    workspace.load_payload(payload, exclude_keys=None, include_keys=None)

    # get policy from workspace; the EMA copy is used when the config enables it
    policy = workspace.model
    if cfg.training.use_ema:
        policy = workspace.ema_model

    device = torch.device(device)
    policy.to(device)
    policy.eval()

    # log cfg of policy
    print("----- Policy Setup -----")
    print("T: {}, To: {}".format(policy.horizon, policy.n_obs_steps))
    print("Da: {}, Do: {}".format(policy.action_dim, policy.obs_feature_dim))
    print("Data Type: {}, obs_as_cond: {}".format(policy.dtype, policy.obs_as_cond))
    print("pred_action_steps_only: {}".format(policy.pred_action_steps_only))
    print("n_action_steps: {}".format(policy.n_action_steps))

    # run eval and time the full rollout
    env_runner = hydra.utils.instantiate(
        cfg.task.env_runner,
        output_dir=output_dir)
    start = time.time()
    runner_log = env_runner.run(policy)
    end = time.time()
    print("\n {} \n".format(end-start))

    # Previously an unconditional `return` sat here and left the JSON dump
    # below unreachable; the flag keeps that default while making it usable.
    if not save_log:
        return

    # dump log to json
    json_log = dict()
    for key, value in runner_log.items():
        if isinstance(value, wandb.sdk.data_types.video.Video):
            # wandb videos are not JSON-serializable; store the file path instead
            json_log[key] = value._path
        else:
            json_log[key] = value
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, 'eval_log.json')
    # context manager guarantees the log is flushed and the handle closed
    with open(out_path, 'w') as log_file:
        json.dump(json_log, log_file, indent=2, sort_keys=True)

# Run the evaluation with the settings defined at the top of this cell.
main(checkpoint, output_dir, device)

Workspace: TrainDiffusionTransformerHybridWorkspace
policy: DiffusionTransformerHybridImagePolicy
environment: PushTImageRunner


using obs modality: low_dim with keys: ['agent_pos']
using obs modality: rgb with keys: ['image']
using obs modality: depth with keys: []
using obs modality: scan with keys: []




----- Policy Setup -----
T: 10, To: 2
Da: 2, Do: 66
Data Type: torch.float32, obs_as_cond: True
pred_action_steps_only: False
n_action_steps: 8
pygame 2.1.2 (SDL 2.0.16, Python 3.9.18)
Hello from the pygame community. https://www.pygame.org/contribute.html
----- Environment Setup -----
n_obs_steps: 2, n_action_steps: 8
n_train: 6,     n_test: 50
max_step: 300
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           