In [1]:
import sys
# 清空 Notebook 默认的参数，然后添加你需要的参数
sys.argv = ["", "--policy.path=lerobot/pi0", "--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"]

import draccus
from lerobot.configs.train import TrainPipelineConfig, PreTrainedConfig
from lerobot.configs import parser
from lerobot.common.datasets.transforms import ImageTransforms
from lerobot.common.datasets.lerobot_dataset import (
    LeRobotDataset,
    LeRobotDatasetMetadata,
    MultiLeRobotDataset,
)
import math
from lerobot.common.utils.utils import get_safe_dtype
from transformers import (
    AutoConfig,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PaliGemmaForConditionalGeneration,
    GemmaForCausalLM
)
from transformers.models.auto import CONFIG_MAPPING

from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.configs.types import FeatureType
import torch.nn as nn
import torch
from lerobot.common.datasets.utils import cycle
from lerobot.common.policies.normalize import Normalize, Unnormalize
import torch.nn.functional as F  # noqa: N812


def resolve_delta_timestamps(
    cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata
) -> dict[str, list] | None:
    """Resolves delta_timestamps by reading from the 'delta_indices' properties of the PreTrainedConfig.

    Args:
        cfg (PreTrainedConfig): The PreTrainedConfig to read delta_indices from.
        ds_meta (LeRobotDatasetMetadata): The dataset from which features and fps are used to build
            delta_timestamps against.

    Returns:
        dict[str, list] | None: A dictionary of delta_timestamps, e.g.:
            {
                "observation.state": [-0.04, -0.02, 0]
                "observation.action": [-0.02, 0, 0.02]
            }
            returns `None` if the the resulting dict is empty.
    """
    delta_timestamps = {}
    for key in ds_meta.features:
        if key == "next.reward" and cfg.reward_delta_indices is not None:
            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
        if key == "action" and cfg.action_delta_indices is not None:
            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
        if key.startswith("observation.") and cfg.observation_delta_indices is not None:
            delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]

    if len(delta_timestamps) == 0:
        delta_timestamps = None

    return delta_timestamps




  from .autonotebook import tqdm as notebook_tqdm
2025-02-14 05:15:09.538384: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-14 05:15:09.781913: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1739538909.881029  213996 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739538909.913509  213996 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-14 05:15:10.159256: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorF

In [2]:
class PaliGemmaWithExpertConfig(PretrainedConfig):
    model_type = "PaliGemmaWithExpertModel"
    sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}

    def __init__(
        self,
        paligemma_config: dict | None = None,
        gemma_expert_config: dict | None = None,
        freeze_vision_encoder: bool = True,
        train_expert_only: bool = True,
        attention_implementation: str = "eager",
        **kwargs,
    ):
        self.freeze_vision_encoder = freeze_vision_encoder
        self.train_expert_only = train_expert_only
        self.attention_implementation = attention_implementation

        if paligemma_config is None:
            # Default config from Pi0
            self.paligemma_config = CONFIG_MAPPING["paligemma"](
                transformers_version="4.48.1",
                _vocab_size=257152,
                bos_token_id=2,
                eos_token_id=1,
                hidden_size=2048,
                image_token_index=257152,
                model_type="paligemma",
                pad_token_id=0,
                projection_dim=2048,
                text_config={
                    "hidden_activation": "gelu_pytorch_tanh",
                    "hidden_size": 2048,
                    "intermediate_size": 16384,
                    "model_type": "gemma",
                    "num_attention_heads": 8,
                    "num_hidden_layers": 18,
                    "num_image_tokens": 256,
                    "num_key_value_heads": 1,
                    "torch_dtype": "float32",
                    "vocab_size": 257152,
                },
                vision_config={
                    "hidden_size": 1152,
                    "intermediate_size": 4304,
                    "model_type": "siglip_vision_model",
                    "num_attention_heads": 16,
                    "num_hidden_layers": 27,
                    "num_image_tokens": 256,
                    "patch_size": 14,
                    "projection_dim": 2048,
                    "projector_hidden_act": "gelu_fast",
                    "torch_dtype": "float32",
                    "vision_use_head": False,
                },
            )
        elif isinstance(self.paligemma_config, dict):
            # Override Pi0 default config for PaliGemma
            if "model_type" not in gemma_expert_config:
                paligemma_config["model_type"] = "paligemma"

            cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
            self.paligemma_config = cfg_cls(**paligemma_config)

        if gemma_expert_config is None:
            # Default config from Pi0
            self.gemma_expert_config = CONFIG_MAPPING["gemma"](
                attention_bias=False,
                attention_dropout=0.0,
                bos_token_id=2,
                eos_token_id=1,
                head_dim=256,
                hidden_act="gelu_pytorch_tanh",
                hidden_activation="gelu_pytorch_tanh",
                hidden_size=1024,
                initializer_range=0.02,
                intermediate_size=4096,
                max_position_embeddings=8192,
                model_type="gemma",
                num_attention_heads=8,
                num_hidden_layers=18,
                num_key_value_heads=1,
                pad_token_id=0,
                rms_norm_eps=1e-06,
                rope_theta=10000.0,
                torch_dtype="float32",
                transformers_version="4.48.1",
                use_cache=True,
                vocab_size=257152,
            )
        elif isinstance(self.gemma_expert_config, dict):
            # Override Pi0 default config for Gemma Expert
            if "model_type" not in gemma_expert_config:
                gemma_expert_config["model_type"] = "gemma"

            cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
            self.gemma_expert_config = cfg_cls(**gemma_expert_config)

        super().__init__(**kwargs)

    def __post_init__(self):
        super().__post_init__()
        if self.train_expert_only and not self.freeze_vision_encoder:
            raise ValueError(
                "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
            )

        if self.attention_implementation not in ["eager", "fa2", "flex"]:
            raise ValueError(
                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
            )

In [3]:
def apply_rope(x, positions, max_wavelength=20_000):
    print("apply_rope")
    print(f"query states shape: {x.shape}")
    dtype = x.dtype
    d_half = x.shape[-1] // 2
    print(f"d_half: {d_half}")
    x = x.to(torch.float32)
    print(f"x shape: {x.shape}")
    freq_exponents = (2 / x.shape[-1]) * torch.arange(d_half, device=x.device, dtype=x.dtype)
    print(f"freq_exponents shape: {freq_exponents.shape}")
    timescale = max_wavelength**freq_exponents
    print(f"timescale shape: {timescale.shape}")
    print(f"positions shape: {positions.shape}")
    print(f"positions[..., None] shape: {positions[..., None].shape}")
    print(f"timescale[None, None, :] shape: {timescale[None, None, :].shape}")
    radians = positions[..., None].to(torch.float32) * timescale[None, None, :]
    print(f"radians shape: {radians.shape}")
    #x shape: torch.Size([1, 355, 8, 256])
    radians = radians[..., None, :]

    sin = torch.sin(radians)
    cos = torch.cos(radians)

    x1, x2 = x.split(d_half, dim=-1)
    res = torch.empty_like(x)
    res[..., :d_half] = x1 * cos - x2 * sin
    res[..., d_half:] = x1 * sin + x2 * cos

    return res.to(dtype)

In [4]:
class PaliGemmaWithExpertModel(PreTrainedModel):
    config_class = PaliGemmaWithExpertConfig

    def __init__(self, config: PaliGemmaWithExpertConfig):
        super().__init__(config=config)
        self.config = config
        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
        self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
        self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
        params_to_change_dtype = [
            "language_model.model.layers",
            "gemma_expert.model.layers",
            "vision_tower",
            "multi_modal",
        ]
        for name, param in self.named_parameters():
            if any(selector in name for selector in params_to_change_dtype):
                param.data = param.data.to(dtype=torch.bfloat16)

    def embed_image(self, image):
        return self.paligemma.get_image_features(image)
    
    def embed_lang_tokens(self, lang_tokens):
        return self.paligemma.language_model.model.embed_tokens(lang_tokens)
    
    def forward(self, attention_mask, position_ids, inputs_embeds):
        models = [self.paligemma.language_model.model, self.gemma_expert.model]
        
        num_layers = self.paligemma.config.text_config.num_hidden_layers
        head_dim = self.paligemma.config.text_config.head_dim
        batch_size = inputs_embeds[0].shape[0]
        for layer_idx in range(num_layers):
            query_states = []
            key_states = []
            value_states = []
            for i, hidden_states in enumerate(inputs_embeds):
                layer = models[i].layers[layer_idx]
                print(f"hideen_states.shape: {hidden_states.shape}")
                hidden_states = layer.input_layernorm(hidden_states)
                print(f"hideen_states.shape: {hidden_states.shape}")

                # hideen_states.shape: torch.Size([1, 304, 2048]) prefix
                # hideen_states.shape: torch.Size([1, 51, 1024]) suffix
                input_shape = hidden_states.shape[:-1]
                hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)

                hidden_states = hidden_states.to(dtype=torch.bfloat16)

                print(f"hideen_states.shape: {hidden_states.shape}")

                query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
                key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
                value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)

                print(f"query_state.shape: {query_state.shape}")
                print(f"key_state.shape: {key_state.shape}")
                print(f"value_state.shape: {value_state.shape}")

                query_states.append(query_state)
                key_states.append(key_state)
                value_states.append(value_state)

            query_states = torch.cat(query_states, dim=1)
            key_states = torch.cat(key_states, dim=1)
            value_states = torch.cat(value_states, dim=1)

            query_states = apply_rope(query_states, position_ids)
            print(f"after rope query_states.shape: {query_states.shape}")
            key_states = apply_rope(key_states, position_ids)
            
            att_output = self.eager_attention_forward(attention_mask, batch_size, head_dim, query_states, key_states, value_states)
            print(f"att_output after qkv: {att_output.shape}")

            att_output = att_output.to(dtype=torch.bfloat16)

            outputs_embeds = []
            start = 0
            for i, hidden_states in enumerate(inputs_embeds):
                layer = models[i].layers[layer_idx]
                if hidden_states is not None:
                    end = start + hidden_states.shape[1]
                    out_emb = layer.self_attn.o_proj(att_output[:, start:end])
                    out_emb += hidden_states
                    after_first_residual = out_emb.clone()

                    out_emb = layer.post_attention_layernorm(out_emb)

                    out_emb = layer.mlp(out_emb)

                    out_emb += after_first_residual

                    outputs_embeds.append(out_emb)
                    start = end
                else:
                    outputs_embeds.append(None)

            inputs_embeds = outputs_embeds
        
        outputs_embeds = []
        for i, hidden_states in enumerate(inputs_embeds):
            if hidden_states is not None:
                out_emb = models[i].norm(hidden_states)
                outputs_embeds.append(out_emb)
            else:
                outputs_embeds.append(None)
        
        return outputs_embeds


    def eager_attention_forward(
        self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
    ):
        num_attn_heads = self.paligemma.config.text_config.num_attention_heads
        num_key_value_heads = self.paligemma.config.text_config.num_key_value_heads
        print(f"num_attn_heads: {num_attn_heads}")
        print(f"num_key_value_heads: {num_key_value_heads}")
        num_key_value_groups = num_attn_heads // num_key_value_heads
        print(f"num_key_value_groups: {num_key_value_groups}")

        sequence_length = key_states.shape[1]
        print(f"sequence_length: {sequence_length}")

        print(f"value_states.shape: {value_states.shape}")

        # key_states.shape: torch.Size([1, 304, 1, 256])
        key_states = key_states[:, :, :, None, :].expand(
            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
        )

        key_states = key_states.reshape(
            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
        )

        value_states = value_states[:, :, :, None, :].expand(
            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
        )

        value_states = value_states.reshape(
            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
        )

        print(f"key_states.shape: {key_states.shape}")
        print(f"query_states.shape: {query_states.shape}")

        query_states = query_states.to(dtype=torch.float32)
        key_states = key_states.to(dtype=torch.float32)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)

        # key_states.shape: torch.Size([355, 1, 8, 256])
        # query_states.shape: torch.Size([355, 1, 8, 256])

        att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        att_weights *= head_dim ** -0.5

        # att_weights.shape: torch.Size([1, 8, 355, 355])
        # attention_mask.shape: torch.Size([1, 355, 355])
        big_neg = -2.3819763e38
        att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
        print(f"att_weights.shape: {att_weights.shape}")
        probs = nn.functional.softmax(att_weights, dim=-1)
        print(f"probs.shape: {probs.shape}")

        probs = probs.to(dtype=torch.bfloat16)
        print(f"value_states.shape: {value_states.shape}")
        attn_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
        print(f"attn_output.shape: {attn_output.shape}")
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
        print(f"attn_output.shape: {attn_output.shape}")

        return attn_output

In [14]:
class Pi0(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        paligemma_with_export_config = PaliGemmaWithExpertConfig(
            freeze_vision_encoder=cfg.policy.freeze_vision_encoder,
            train_expert_only=cfg.policy.train_expert_only,
            attention_implementation=cfg.policy.attention_implementation,
        )
        self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
        self.state_proj = nn.Linear(cfg.policy.max_state_dim, cfg.policy.proj_width).to(torch.bfloat16)
        self.action_proj = nn.Linear(cfg.policy.max_action_dim, cfg.policy.proj_width).to(torch.bfloat16)

        self.action_time_mlp_in = nn.Linear(cfg.policy.proj_width * 2, cfg.policy.proj_width).to(torch.bfloat16)
        self.action_time_mlp_out = nn.Linear(cfg.policy.proj_width, cfg.policy.proj_width).to(torch.bfloat16)
        self.action_out_proj = nn.Linear(cfg.policy.proj_width, cfg.policy.max_action_dim).to(torch.bfloat16)
        self.cfg = cfg

    def create_sinusoidal_pos_embedding(self, 
        time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"):
        """Computes sine-cosine positional embedding vectors for scalar positions."""
        if dimension % 2 != 0:
            raise ValueError(f"dimension ({dimension}) must be divisible by 2")

        if time.ndim != 1:
            raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

        dtype = get_safe_dtype(torch.float64, device.type)
        fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
        period = min_period * (max_period / min_period) ** fraction

        # Compute the outer product
        scaling_factor = 1.0 / period * 2 * math.pi
        sin_input = scaling_factor[None, :] * time[:, None]
        pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
        return pos_emb

    def embed_prefix(self, images, img_masks, lang_tokens, lang_masks):
        embs = []
        pad_maks = []
        att_masks = []

        for (img, img_mask) in zip(images, img_masks, strict=True):
            img_emb = self.paligemma_with_expert.embed_image(img)
            print("img_emb", img_emb.shape)
            img_emb = img_emb.to(dtype=torch.bfloat16)

            # img_emb torch.Size([1, 256, 2048])
            # Normalize image embeddings
            img_emb_dim = img_emb.shape[-1]
            img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)

            print("img_mask", img_mask.shape)

            b_size, num_img_embds = img_emb.shape[:2]
            img_mask = img_mask[:, None].expand(b_size, num_img_embds)
            print("img_mask", img_mask.shape)

            embs.append(img_emb)
            pad_maks.append(img_mask)

            att_masks += [0] * num_img_embds
        
        lang_emb = self.paligemma_with_expert.embed_lang_tokens(lang_tokens)

        print("lang_emb", lang_emb.shape)

        lang_emb_dim = lang_emb.shape[-1]
        lang_emb = lang_emb * math.sqrt(lang_emb_dim)
        
        print("lang_emb_dim", lang_emb_dim)
        print("lang_masks", lang_masks.shape)

        embs.append(lang_emb)
        pad_maks.append(lang_masks)

        att_masks += [0] * lang_emb.shape[1]

        embs = torch.cat(embs, dim=1)
        pad_maks = torch.cat(pad_maks, dim=1)

        attn_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_maks.device)
        att_masks = attn_masks[None, :].expand(b_size, len(att_masks))

        return embs, pad_maks, att_masks

    def emb_suffix(self, states, noisy_actions, timestep):
        embs = []
        pad_maks = []
        att_masks = []

        state_emb = self.state_proj(states)
        print(f'state_emb {state_emb.shape}')
        state_emb = state_emb.to(dtype=torch.bfloat16)
        embs.append(state_emb[:, None, :])

        state_mask = torch.ones((states.shape[0], 1), dtype=torch.bool, device=state_emb.device)
        pad_maks.append(state_mask)

        att_masks += [1]

        time_emb = self.create_sinusoidal_pos_embedding(timestep, self.cfg.policy.proj_width, min_period=0.001, max_period=4.0, device=state_emb.device) 
        time_emb = time_emb.to(dtype=state_emb.dtype)

        action_emb = self.action_proj(noisy_actions)
        print(f'action_emb {action_emb.shape}')
        print(f'time_emb {time_emb.shape}')

        time_emb = time_emb[:, None, :].expand(-1, action_emb.shape[1], -1)
        action_time_emb = torch.cat([action_emb, time_emb], dim=-1)
        print(f'action_time_emb {action_time_emb.shape}')

        action_time_emb = self.action_time_mlp_in(action_time_emb)
        action_time_emb = F.silu(action_time_emb)
        action_time_emb = self.action_time_mlp_out(action_time_emb)

        embs.append(action_time_emb)

        action_mask = torch.ones(action_time_emb.shape[0], action_time_emb.shape[1], dtype=torch.bool, device=action_time_emb.device)
        pad_maks.append(action_mask)

        att_masks += [1] + ([0] * (self.cfg.policy.n_action_steps - 1))

        print(f"embs len {len(embs)}")
        for i, emb in enumerate(embs):
            print(f"Embedding {i} shape: {emb.shape}")

        print(f"action_mask len {len(pad_maks)}")
        for i, emb in enumerate(pad_maks):
            print(f"action_mask {i} shape: {emb.shape}")

        att_masks = torch.tensor(att_masks, dtype=action_time_emb.dtype, device=action_time_emb.device)

        print(f"att_masks shape {att_masks.shape}")
        
        embs = torch.cat(embs, dim=1)
        pad_maks = torch.cat(pad_maks, dim=1)
        att_masks = att_masks[None, :].expand(embs.shape[0], len(att_masks))

        return embs, pad_maks, att_masks



    
    def forward(self, images, img_masks, lang_tokens, lang_masks, states, actions):
        noise = torch.normal(
            mean=0.0, std=1.0, size=actions.shape,
            dtype=torch.bfloat16, device=actions.device,
        )
        gamma1 = torch.empty((actions.shape[0],), device=actions.device).uniform_(0, 1).pow(1 / 1.5)
        gamma2 = torch.empty((actions.shape[0],), device=actions.device).uniform_(0, 1).pow(1 / 1.0)  
        time_beta = gamma1 / (gamma1 + gamma2)
        time = time_beta * 0.999 + 0.001
        time = time.to(dtype=torch.bfloat16, device=actions.device)

        time_expanded = time[:, None, None]
        print(time_expanded.shape)
        print(noise.shape)
        print(actions.shape)
        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions

        prefix_embs, prefix_pad_masks, prefix_att_masks =  self.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )

        suffix_embs, suffix_pad_masks, suffix_att_masks =  self.emb_suffix(states, x_t, time)


        print(f"prefix_embs {prefix_embs.shape}")
        print(f"prefix_pad_masks {prefix_pad_masks.shape}")
        print(f"prefix_att_masks {prefix_att_masks.shape}")
        print(f"suffix_embs {suffix_embs.shape}")
        print(f"suffix_pad_masks {suffix_pad_masks.shape}")
        print(f"suffix_att_masks {suffix_att_masks.shape}")

        pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
        att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)

        cumsum = torch.cumsum(att_masks, dim=1)
        att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
        pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
        att_2d_masks = att_2d_masks * pad_2d_masks
        
        position_ids = torch.cumsum(pad_masks, dim=1) - 1
        
        (_, outputs_embeds) = self.paligemma_with_expert.forward(
            attention_mask=att_2d_masks,
            position_ids=position_ids,
            inputs_embeds=[prefix_embs, suffix_embs],
        )

        #see the shape of output embeds
        for i, emb in enumerate(outputs_embeds):
            print(f"Output Embedding {i} shape: {emb.shape}")

        suffix_out = outputs_embeds[:,-self.cfg.policy.n_action_steps:]
        print(f"suffix_out shape: {suffix_out.shape}")

        suffix_out = suffix_out.to(dtype=torch.bfloat16)
        v_t = self.action_out_proj(suffix_out)
        print(f"v_t shape: {v_t.shape}")

        losses = F.mse_loss(v_t, u_t, reduction="none")
        print(f"losses: {losses.mean()}")

        return losses

In [15]:
class ModelRunner(nn.Module):
    def __init__(self, cfg, ds_meta):
        super().__init__()
        self.cfg = cfg
        self.normalize_inputs = Normalize(cfg.input_features, cfg.policy.normalization_mapping, ds_meta.stats)
        self.normalize_targets = Normalize(
            cfg.output_features, cfg.policy.normalization_mapping, ds_meta.stats
        )
        self.unnormalize_outputs = Unnormalize(
            cfg.output_features, cfg.policy.normalization_mapping, ds_meta.stats
        )
        self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
        self.model = Pi0(cfg)

    def resize_with_pad(sefl, img, width, height, pad_value=-1):
        # assume no-op when width height fits already
        if img.ndim != 4:
            raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

        cur_height, cur_width = img.shape[2:]
        #img shape: torch.Size([1, 3, 480, 640])
  
        ratio = max(cur_width / width, cur_height / height)
        resized_height = int(cur_height / ratio)
        resized_width = int(cur_width / ratio)
        resized_img = F.interpolate(
            img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
        )

        pad_height = max(0, int(height - resized_height))
        pad_width = max(0, int(width - resized_width))
        

        # pad on left and top of image
        padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)

        return padded_img

    def prepare_image(self, batch):
        images = []
        img_masks = []
        present_img_keys = [key for key in self.cfg.image_features if key in batch]

        for key in present_img_keys:
            print(f"processing img key: {key}")
            img = batch[key]
            if self.cfg.policy.resize_imgs_with_padding is not None:
                img = self.resize_with_pad(img, *self.cfg.policy.resize_imgs_with_padding, pad_value=0)
                print(f"img shape: {img.shape}")

            print(f"img: {img}")
            img = img * 2.0 - 1.0
            print(f"img: {img}")

            mask = torch.ones(img.shape[0], dtype=torch.bool, device=img.device)
            images.append(img)
            img_masks.append(mask)

        return images, img_masks

    def prepare_language(self, batch):
        tasks = batch["task"]
        device = batch["observation.state"].device
        print(tasks)
        tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
        print(tasks)
        print(self.cfg.policy.tokenizer_max_length)

        tokenized_prompt = self.language_tokenizer.__call__(
            tasks,
            padding="max_length",
            padding_side="right",
            max_length=self.cfg.policy.tokenizer_max_length,
            return_tensors="pt",
        )

        lang_tokens = tokenized_prompt["input_ids"].to(device=device)
        lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)

        return lang_tokens, lang_masks
    
    def pad_vector(self, vector, new_dim):
        shape = list(vector.shape)
        curren_dim = shape[-1]
        shape[-1] = new_dim
        new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
        new_vector[..., :curren_dim] = vector
        # paded_vector = torch.zeros(vector.shape[0], new_dim, device=vector.device)
        # paded_vector[:, :vector.shape[1]] = vector
        return new_vector


    def prepare_action(self, batch):
        actions = self.pad_vector(batch["action"], self.cfg.policy.max_action_dim)
        return actions
    
    def prepare_state(self, batch):
        states = self.pad_vector(batch["observation.state"], self.cfg.policy.max_state_dim)
        return states

    def forward(self, batch):
        print(f"before normalize: {batch}")
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)
        print(f"after normalize: {batch}")

        images, img_masks = self.prepare_image(batch)
        lang_tokens, lang_masks = self.prepare_language(batch)
        actions = self.prepare_action(batch)
        states = self.prepare_state(batch)
        
        for img in images:
            img = img.to(torch.bfloat16)
        for img_mask in img_masks:
            img_mask = img_mask.to(torch.bfloat16)
        for lang_tk in lang_tokens:
            lang_tk = lang_tk.to(torch.bfloat16)
        for lang_mask in lang_masks:
            lang_mask = lang_mask.to(torch.bfloat16)
        actions = actions.to(torch.bfloat16)        
        states = states.to(torch.bfloat16)

        self.model.forward(images, img_masks, lang_tokens, lang_masks, states, actions)

        return batch.keys()

In [16]:
@parser.wrap()
def train(cfg: TrainPipelineConfig):
    print("Before from_pretrained call, cfg is:", cfg)
    cfg.validate()
    print("after from_pretrained call, cfg is:", cfg)

    image_transforms = (
        ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
    )

    ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)

    features = dataset_to_policy_features(ds_meta.features)

    cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
    print(f'cfg.input_features is {cfg.input_features}')
    delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)

    cfg.image_features = {key: ft for key, ft in cfg.input_features.items() if ft.type is FeatureType.VISUAL}
    print(f'cfg.image_features is {cfg.image_features}')

    dataset = LeRobotDataset(
        cfg.dataset.repo_id,
        episodes=cfg.dataset.episodes,
        delta_timestamps=delta_timestamps,
        image_transforms=image_transforms,
        video_backend=cfg.dataset.video_backend,
        local_files_only=cfg.dataset.local_files_only,
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=1,
        shuffle=False,
        pin_memory="gpu",
    )
    dl_iter = cycle(dataloader)
    
    model = ModelRunner(cfg, ds_meta)
    for _ in range(1):
        batch = next(dl_iter)
        output = model(batch)
        # print(output)



train()

argspec is FullArgSpec(args=['cfg'], varargs=None, varkw=None, defaults=None, kwonlyargs=[], kwonlydefaults=None, annotations={'cfg': <class 'lerobot.configs.train.TrainPipelineConfig'>})
argtype is <class 'lerobot.configs.train.TrainPipelineConfig'>
cli_args is ['--policy.path=lerobot/pi0', '--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human']
config_path_cli is None
path_fields is ['policy']
cli_args is ['--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human']
argtype is <class 'lerobot.configs.train.TrainPipelineConfig'>
config_path is None
cli_args is ['--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human']




cfg is TrainPipelineConfig(dataset=DatasetConfig(repo_id='lerobot/aloha_sim_transfer_cube_human', episodes=None, image_transforms=ImageTransformsConfig(enable=False, max_num_transforms=3, random_order=False, tfs={'brightness': ImageTransformConfig(weight=1.0, type='ColorJitter', kwargs={'brightness': (0.8, 1.2)}), 'contrast': ImageTransformConfig(weight=1.0, type='ColorJitter', kwargs={'contrast': (0.8, 1.2)}), 'saturation': ImageTransformConfig(weight=1.0, type='ColorJitter', kwargs={'saturation': (0.5, 1.5)}), 'hue': ImageTransformConfig(weight=1.0, type='ColorJitter', kwargs={'hue': (-0.05, 0.05)}), 'sharpness': ImageTransformConfig(weight=1.0, type='SharpnessJitter', kwargs={'sharpness': (0.5, 1.5)})}), local_files_only=False, use_imagenet_stats=True, video_backend='pyav'), env=None, policy=None, output_dir=None, job_name=None, resume=False, device=None, use_amp=False, seed=1000, num_workers=4, batch_size=4, eval_freq=20000, log_freq=200, save_checkpoint=True, save_freq=20, offline

Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 987.24it/s]


cfg.input_features is {'observation.images.top': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640)), 'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(14,))}
cfg.image_features is {'observation.images.top': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 480, 640))}


Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 3947.58it/s]
Fetching 106 files: 100%|██████████| 106/106 [00:00<00:00, 1253.62it/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/t

before normalize: {'observation.images.top': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]]), 'observation.state': tensor([[ 0.0000, -0.9600,  1.1600,  0.0000, -0.3000,  0.0000,  0.0000,  0.0000,
         -0.9600,  1.1600,  0.0000, -0.3000,

after from_pretrained call, cfg is: TrainPipelineConfig(dataset=DatasetConfig(repo_id='lerobot/aloha_sim_transfer_cube_human', episodes=None, image_transforms=ImageTransformsConfig(enable=False, max_num_transforms=3, random_order=False, tfs={'brightness': ImageTransformConfig(weight=1.0, type='ColorJitter', kwargs={'brightness': (0.8, 1.2)}), 'contrast': ImageTransformConfig(weight=1.0, type='ColorJitter', kwargs={'contrast': (0.8, 1.2)}), 'saturation': ImageTransformConfig(weight=1.0, type='ColorJitter', kwargs={'saturation': (0.5, 1.5)}), 'hue': ImageTransformConfig(weight=1.0, type='ColorJitter', kwargs={'hue': (-0.05, 0.05)}), 'sharpness': ImageTransformConfig(weight=1.0, type='SharpnessJitter', kwargs={'sharpness': (0.5, 1.5)})}), local_files_only=False, use_imagenet_stats=True, video_backend='pyav'), env=None, policy=PI0Config(n_obs_steps=1, normalization_mapping={'VISUAL': <NormalizationMode.IDENTITY: 'IDENTITY'>, 'STATE': <NormalizationMode.MEAN_STD: 'MEAN_STD'>, 'ACTION': <NormalizationMode.MEAN_STD: 'MEAN_STD'>}, input_features={}, output_features={}, chunk_size=50, n_action_steps=50, max_state_dim=32, max_action_dim=32, resize_imgs_with_padding=(224, 224), empty_cameras=0, adapt_to_pi_aloha=False, use_delta_joint_actions_aloha=False, tokenizer_max_length=48, proj_width=1024, num_steps=10, use_cache=True, attention_implementation='eager', freeze_vision_encoder=True, train_expert_only=False, train_state_proj=True, optimizer_betas=(0.9, 0.95), optimizer_eps=1e-08, optimizer_type='sgd', optimizer_lr=2.5e-05, optimizer_momentum=0.9, optimizer_weight_decay=1e-10, scheduler_warmup_steps=1000, scheduler_decay_steps=30000, scheduler_decay_lr=2.5e-06), output_dir=PosixPath('outputs/train/2025-02-10/23-39-54_pi0'), job_name='pi0', resume=False, device='cuda', use_amp=False, seed=1000, num_workers=4, batch_size=4, eval_freq=20000, log_freq=200, save_checkpoint=True, save_freq=20, offline=OfflineConfig(steps=100000), online=OnlineConfig(steps=0, rollout_n_episodes=1, rollout_batch_size=1, steps_between_rollouts=None, sampling_ratio=0.5, env_seed=None, buffer_capacity=None, buffer_seed_size=0, do_rollout_async=False), use_policy_training_preset=True, optimizer=SGDConfig(lr=2.5e-05, weight_decay=1e-10, grad_clip_norm=10.0, momentum=0.9, dampening=0.0, nesterov=False), scheduler=CosineDecayWithWarmupSchedulerConfig(num_warmup_steps=1000, num_decay_steps=30000, peak_lr=2.5e-05, decay_lr=2.5e-06), eval=EvalConfig(n_episodes=50, batch_size=50, use_async_envs=False), wandb=WandBConfig(enable=False, disable_artifact=False, project='lerobot', entity=None, notes=None))


In [10]:
config = draccus.parse(PreTrainedConfig, config_file, [])

In [None]:
config