In [20]:
import os

# Point the OpenPI data directory and the Hugging Face / XDG cache locations at
# fixed directories under /mnt. This runs ahead of the project imports that
# follow, presumably so the values are already set when those libraries are
# imported -- TODO confirm which of them actually read these variables.
os.environ.update(
    {
        "OPENPI_DATA_HOME": "/mnt/virtual_ai0001071-01239_SR006-nfs2/apanasevich/openpi/assets",
        "HF_HOME": "/mnt/virtual_ai0001071-01239_SR006-nfs2/.cache/huggingface",
        "XDG_CACHE_HOME": "/mnt/virtual_ai0001071-01239_SR006-nfs2/.cache",
    }
)


from lerobot.common.datasets.create_dataloader import create_lerobot_dataloader
from lerobot.common.datasets.data_config import (
    LeRobotAgibotTwoFingerDataConfig,
    LeRobotAgibotDexHandDataConfig,
)
from rich_argparse import RichHelpFormatter
import argparse

import torch

from accelerate.utils import InitProcessGroupKwargs
from datetime import timedelta

from lerobot.common.datasets.data_config import (
    AssetsConfig as LeRobotAssetsConfig
)
from lerobot.common.datasets.data_config import DataConfig as LeRobotBaseDataConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from tqdm import tqdm
from accelerate import Accelerator
from lerobot.common.utils.normalize import RunningStats, save as save_stats
import torch.distributed as dist
from pathlib import Path
from omegaconf import OmegaConf
from hydra.utils import instantiate
from lerobot.common.datasets.create_dataloader import create_lerobot_dataset_by_config
from torch.utils.data import DataLoader


class _ActionChunkDataset(torch.utils.data.Dataset):
    """Dataset view that exposes only a cropped ``'actions'`` entry per sample.

    This class lives at module level on purpose: a class defined inside a
    function cannot be pickled, which would prevent instances from being sent
    to DataLoader worker processes started with ``spawn``, or from being saved
    or deep-copied. The wrapped dataset must of course be picklable as well.

    Args:
        dataset: Indexable dataset whose items are mappings with an
            ``'actions'`` entry that supports 2-D slicing and has ``.shape``.
        action_horizon: Number of leading rows to keep; ``None`` keeps all.
        action_dim: Number of leading columns to keep; ``None`` means "use the
            width of the first sample's actions".
    """

    def __init__(self, dataset, action_horizon, action_dim):
        self.dataset = dataset
        self.action_horizon = action_horizon
        self.action_dim = action_dim

        if self.action_dim is None:
            # Resolve the default once, from the first sample. This reads
            # sample 0 eagerly at construction time.
            self.action_dim = self.dataset[0]['actions'].shape[1]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # With ``action_horizon=None`` the row slice is ``[:None]``, i.e. all rows.
        return self.dataset[idx]['actions'][:self.action_horizon, :self.action_dim]


def get_data(
    dataset_config_path: str,
    assets_dir: str,
    action_horizon: int | None = None,
    action_dim: int | None = None
) -> torch.utils.data.Dataset:
    """Build a dataset that yields only the (cropped) actions of each sample.

    The config file is loaded with OmegaConf, fully resolved into plain Python
    containers, instantiated through Hydra, and handed to
    ``create_lerobot_dataset_by_config`` with normalization statistics and
    model transforms disabled.

    Args:
        dataset_config_path: Path to the dataset config file read by
            ``OmegaConf.load``.
        assets_dir: Forwarded as ``assets_dirs`` to
            ``create_lerobot_dataset_by_config``.
        action_horizon: When given, overrides ``cfg['action_horizon']`` before
            instantiation and also caps the number of action rows returned per
            sample. ``None`` leaves the config untouched and returns all rows.
        action_dim: Number of leading action columns to keep. ``None`` keeps
            the full width of the first sample's actions.

    Returns:
        A ``torch.utils.data.Dataset`` whose ``__getitem__(i)`` returns
        ``sample['actions'][:action_horizon, :action_dim]`` for the ``i``-th
        sample of the underlying dataset.
    """
    cfg = OmegaConf.load(dataset_config_path)
    # Convert to plain containers with interpolations resolved so the override
    # below is an ordinary dict assignment.
    cfg = OmegaConf.to_container(cfg, resolve=True)

    if action_horizon is not None:
        cfg['action_horizon'] = action_horizon

    data_config = instantiate(cfg)
    # A default PI0 config is supplied as the model config; model transforms
    # are skipped below, so presumably its details do not affect the output.
    model_cfg = PI0Config()
    lerobot_dataset = create_lerobot_dataset_by_config(
        data_config_factory=data_config,
        model_config=model_cfg,
        assets_dirs=assets_dir,
        # Original author's note: the mode "does not matter" here --
        # presumably because norm stats are skipped (skip_norm_stats=True).
        normalization_mode="mean_std",
        skip_norm_stats=True,
        skip_model_transforms=True,
        return_norm_stats=False,
    )

    return _ActionChunkDataset(lerobot_dataset, action_horizon, action_dim)

In [None]:
# Inputs for get_data: the dataset config file and the assets directory.
data_config = "/mnt/virtual_ai0001071-01239_SR006-nfs2/afedorov/projects/constant_repos/lerobot-fork/lerobot/conf/robotics_dataset/individual/aloha_aij.yaml"
assets_dir = "/mnt/virtual_ai0001071-01239_SR006-nfs2/afedorov/assets"

# Each item of `dataset` is the sample's actions cropped to the first 16 rows
# and the first 14 columns.
dataset = get_data(
    dataset_config_path=data_config,
    assets_dir=assets_dir,
    action_horizon=16,
    action_dim=14,
)



In [1]:
import torch
import numpy as np

from beast.bspline_tokenizer import BSpline_Tokenizer

# Hard-coded sample trajectory: 20 rows x 7 columns of float64 values.
# Presumably one row per time step and one column per degree of freedom
# (the last column looking like a gripper channel) -- TODO confirm.
traj = np.array([
    [0.0030412564065136483, 0.011397065331214767, 0.004153039984473, 0.20769455660275007, 0.01583192780193277, 0.10795688293466486, -0.0202178955078125],
    [0.0030404060633614644, 0.011463632408906967, 0.004632541819576591, 0.20772984954523133, 0.010872839293467518, 0.1102459393226217, -0.0202178955078125],
    [0.0030399455857521803, 0.011495262690979776, 0.004964521245309686, 0.20773728740702876, 0.007439597860192517, 0.11133374251807343, -0.020217895507812497],
    [0.0030394721952192822, 0.01149696531112195, 0.0049276342347678885, 0.2077382463575408, 0.007821066276012478, 0.11139399422583293, -0.0202178955078125],
    [0.0030195886702748765, 0.011672412301677806, 0.004964528327010274, 0.2100714170168426, 0.007439353235095415, 0.11749334998581466, -0.0202178955078125],
    [0.0029248382634318813, 0.012396147060055563, 0.004979176806152141, 0.22813870532704844, 0.007285144439889449, 0.14270762937492285, -0.0202178955078125],
    [0.0027553261213416535, 0.013453141796883216, 0.00540790514949117, 0.25944059384325374, 0.0028491896913942143, 0.1796906688090619, -0.0202178955078125],
    [0.0025777912796007348, 0.014606306221451794, 0.008668525312856848, 0.2859914205749503, -0.030537835368051047, 0.22017390890612779, -0.0202178955078125],
    [0.0024326156316412523, 0.015537746219212573, 0.013297968399127114, 0.28831794404080485, -0.0734406098392348, 0.2533056340687476, -0.0202178955078125],
    [0.0018195291545668596, 0.01652541514347037, 0.019259580629448936, 0.30029581367453656, -0.08167634847353575, 0.2882291086389247, -0.0202178955078125],
    [0.0020166062178000738, 0.01756942982760113, 0.03124605959617424, 0.30782634382646107, -0.052241732592717374, 0.30849590901274454, -0.0202178955078125],
    [0.0034421046275634926, 0.019123881409740574, 0.051551754636630025, 0.31181391845498035, 0.014248069524644254, 0.31678936821874953, -0.01983642578125],
    [0.005086894716544088, 0.02043594806837358, 0.06986146817578522, 0.29504959217882487, 0.05673653510580361, 0.3259098710978283, -0.0194549560546875],
    [0.010553318155958221, 0.02400074653186707, 0.11001855132421846, 0.24868097575879394, 0.13389942198760446, 0.3380421666524884, -0.0194549560546875],
    [0.014396791788601412, 0.026696157434030688, 0.13071094018903454, 0.21966229697040063, 0.20331271963151215, 0.3485129395550562, -0.0194549560546875],
    [0.026761144591643495, 0.03305475635776374, 0.16474364778549921, 0.16116127697863467, 0.2962525088315681, 0.3450516911952767, -0.0194549560546875],
    [0.040788309609101125, 0.040698721506989215, 0.19058867605006674, 0.11436381013220927, 0.37466079363808025, 0.33769147455554294, -0.0194549560546875],
    [0.05492529521560714, 0.050642699336754804, 0.20668601689441007, 0.0672495837542147, 0.4817381191952801, 0.321664250089693, -0.0194549560546875],
    [0.07018943364087527, 0.06191335072667651, 0.2208095876804983, 0.016090914566066632, 0.558878027556361, 0.3002506589078401, -0.0194549560546875],
    [0.0890767742911495, 0.07654920164520895, 0.22998926484237925, -0.04314882393132983, 0.6459622289168823, 0.26693012091434154, -0.011368132108831466],
])


# Per-column lower and upper bounds used to min-max scale the trajectory.
action_min = np.array([-0.1, -0.5, -0.5, -3.0, -3.0, -3.0, -1])
action_max = np.array([0.5, 0.5, 0.5, 3.0, 3.0, 3.0, 4.5])

# Run on the GPU when one is available and fall back to the CPU otherwise, so
# the cell does not fail outright on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = BSpline_Tokenizer(num_dof=7, num_basis=10, seq_len=20, gripper_zero_order=False).to(device)

# unsqueeze(0) adds a leading axis: trajs becomes (1, 20, 7) and the bounds
# become (1, 7), which broadcast against trajs in the normalization below.
trajs = torch.from_numpy(traj).unsqueeze(0).to(device)
action_min = torch.from_numpy(action_min).unsqueeze(0).to(device)
action_max = torch.from_numpy(action_max).unsqueeze(0).to(device)

# Normalize the trajectory before tokenization: values equal to action_min map
# to 0 and values equal to action_max map to 1.
trajs = (trajs - action_min) / (action_max - action_min)

# Tokenize the trajectory and visualize the reconstruction error.
# NOTE(review): the saved output of this cell ends with NotImplementedError;
# the cause is not visible from this notebook -- confirm against the
# installed beast version before relying on this call.
tokenizer.visualize_reconstruction_error(trajs)

  from .autonotebook import tqdm as notebook_tqdm


NotImplementedError: 