Evaluate the trained model checkpoint on the dataset by computing the action MSE for several trajectories.

In [None]:
import os
import warnings

import numpy as np
import torch

from gr00t.data.dataset import LeRobotSingleDataset
from gr00t.data.schema import EmbodimentTag
from gr00t.experiment.data_config import DATA_CONFIG_MAP
from gr00t.model.policy import Gr00tPolicy
from gr00t.utils.eval import calc_mse_for_single_trajectory

# Use the GPU when one is available; the chosen device is passed to Gr00tPolicy below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Silence FutureWarning noise from the code run in this notebook.
# NOTE(review): any FutureWarnings raised while importing gr00t above were
# already emitted before this filter is installed.
warnings.simplefilter("ignore", category=FutureWarning)

# Paths default to the original author's cluster locations. Set the environment
# variables below to run the evaluation elsewhere without editing this cell.
# NOTE: despite its name, PRE_TRAINED_MODEL_PATH points at a training
# checkpoint ("checkpoint-5000"), i.e. the model being evaluated.
PRE_TRAINED_MODEL_PATH = os.environ.get(
    "TRACK_HAWK_MODEL_PATH",
    "/mloscratch/users/kalajdzi/track-hawk/checkpoints/checkpoint-5000",
)
EMBODIMENT_TAG = EmbodimentTag.NEW_EMBODIMENT
DATASET_PATH = os.environ.get(
    "TRACK_HAWK_DATASET_PATH",
    "/mloscratch/users/kalajdzi/track-hawk/data_track_hawk/dataset_drone_control",
)
# Key into DATA_CONFIG_MAP selecting the modality config and transforms.
DATA_CONFIG_NAME = os.environ.get("TRACK_HAWK_DATA_CONFIG", "track_hawk")

data_config = DATA_CONFIG_MAP[DATA_CONFIG_NAME]
modality_config = data_config.modality_config()
modality_transform = data_config.transform()

# The policy receives the modality transform, so it is responsible for
# transforming inputs/outputs.
pre_trained_policy = Gr00tPolicy(
    model_path=PRE_TRAINED_MODEL_PATH,
    embodiment_tag=EMBODIMENT_TAG,
    modality_config=modality_config,
    modality_transform=modality_transform,
    device=device,
)

# The dataset is loaded without transforms (transforms=None) because the
# policy above applies them; applying them here too would transform twice.
dataset = LeRobotSingleDataset(
    dataset_path=DATASET_PATH,
    modality_configs=modality_config,
    video_backend="decord",
    video_backend_kwargs=None,
    transforms=None,  # We'll handle transforms separately through the policy
    embodiment_tag=EMBODIMENT_TAG,
)

In [None]:
# Evaluation settings (same values as the original inline constants).
NUM_EVAL_TRAJECTORIES = 6  # trajectory ids 0 .. NUM_EVAL_TRAJECTORIES - 1
EVAL_STEPS = 906  # `steps` passed to calc_mse_for_single_trajectory
ACTION_HORIZON = 16  # `action_horizon` passed to calc_mse_for_single_trajectory
# NOTE(review): assumes the dataset has at least NUM_EVAL_TRAJECTORIES
# trajectories, each long enough for EVAL_STEPS steps. Confirm against the
# dataset metadata.

# Keep every trajectory's result so they can be inspected or aggregated after
# the loop, instead of only the last `mse`.
mse_by_traj = {}
for traj_id in range(NUM_EVAL_TRAJECTORIES):
    print("Running trajectory:", traj_id)
    mse = calc_mse_for_single_trajectory(
        pre_trained_policy,
        dataset,
        traj_id=traj_id,
        state_modality_keys=["drone_state"],
        action_modality_keys=["drone_action"],
        steps=EVAL_STEPS,
        action_horizon=ACTION_HORIZON,
        plot=True,
    )
    mse_by_traj[traj_id] = mse

    print(f"MSE loss for trajectory {traj_id}: {mse}")
    print("==============================")

# Summary across all evaluated trajectories (skipped when none were run).
if mse_by_traj:
    mean_mse = np.mean(list(mse_by_traj.values()))
    print(f"Mean MSE over {len(mse_by_traj)} trajectories: {mean_mse}")