In [1]:
import numpy as np

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from robotics.model_src.dataset import RobosuiteImageActionDataset

# Dataset configuration: robomimic ToolHang demos, camera selection (passed to the
# dataset as-is) and the observation / prediction window lengths.
data_path = "../robomimic/datasets/tool_hang/ph/image_griper.hdf5"
camera_type = None

obs_horizon = 4    # past observation frames used as conditioning by the denoiser
pred_horizon = 8   # future actions predicted per diffusion sample

ds = RobosuiteImageActionDataset(
    data_path,
    camera_type,
    obs_horizon=obs_horizon,
    prediction_horizon=pred_horizon,
)

100%|██████████| 200/200 [00:00<00:00, 4709.79it/s]
100%|██████████| 200/200 [00:00<00:00, 42532.11it/s]


In [4]:
import torch


# Hold out a fraction of the samples for validation.
val_ratio = 0.1
n_total   = len(ds)
n_val     = int(n_total * val_ratio)
n_train   = n_total - n_val
# An empty validation split leaves val_loader with no batches, so the validation
# loss could not be averaged later; fail here with a clear message instead.
assert n_val > 0, f"validation split is empty (len(ds)={n_total}, val_ratio={val_ratio})"

# Fixed-seed generator so train/val membership is reproducible across runs.
generator = torch.Generator().manual_seed(33)
train_set, val_set = torch.utils.data.random_split(
        ds, [n_train, n_val], generator=generator)


train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=512, shuffle=True,
    num_workers=4, pin_memory=True, persistent_workers=True)

val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=512, shuffle=False,
    num_workers=4, pin_memory=True, persistent_workers=True)


# Inspect one training batch to sanity-check tensor shapes.
# Labels now name the keys actually printed.
batch = next(iter(train_loader))
print("batch['img_obs'].shape:", batch['img_obs'].shape)
print("batch['act_obs'].shape:", batch['act_obs'].shape)
print("batch['act_pred'].shape:", batch['act_pred'].shape)

batch['image'].shape: torch.Size([512, 5, 58])
batch['act_obs'].shape: torch.Size([512, 5, 7])
batch['act_pred'].shape torch.Size([512, 8, 7])


In [5]:
from robotics.model_src.diffusion_model import ConditionalUnet1D, ConditionalUnet1DTransformer
from robotics.model_src.visual_encoder import CNNVisualEncoder

# Use the first GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Observation size per frame, read from the dataset. Presumably this equals the last
# dim of batch['img_obs'], since the loss flattens that tensor into the conditioning.
obs_feature_dim = ds.obs_shape[0]

# Per-frame size of the past-action history. It only matters for the disabled variant
# that concatenates actions into the conditioning:
#   obs_dim = obs_feature_dim + action_observation_dim
action_observation_dim = 7
obs_dim = obs_feature_dim

# Size of one robot action (matches the 7-dim act_pred / env action space).
action_dim = 7

# Denoiser over action chunks, conditioned on the flattened window of
# obs_horizon observations.
noise_prediction_net = ConditionalUnet1DTransformer(
    global_cond_dim=obs_horizon * obs_dim,
    input_dim=action_dim,
)
noise_prediction_net = noise_prediction_net.to(device)

  from .autonotebook import tqdm as notebook_tqdm


number of parameters: 9.124929e+07


In [6]:
# Parameter count of the denoiser, derived from the model itself. The previous
# hard-coded literal (9.163636e+07) disagreed with the 9.124929e+07 printed when the
# network was built. NOTE(review): the model's own printout may use a different
# counting convention; this sums numel() over all parameters.
model_size = sum(p.numel() for p in noise_prediction_net.parameters())

In [7]:
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

# Number of diffusion steps, used both for training and for sampling.
num_diffusion_iters = 100

# DDPM scheduler shared by training (add_noise) and sampling (set_timesteps / step).
noise_scheduler = DDPMScheduler(
    beta_schedule='squaredcos_cap_v2',
    prediction_type='epsilon',   # the network is trained to predict the injected noise
    clip_sample=True,            # clip the predicted clean sample during sampling
    num_train_timesteps=num_diffusion_iters,
)

In [12]:
from torch import nn
from diffusers import EMAModel, get_scheduler

def forward_loss(nbatch):
    """Compute the DDPM epsilon-prediction MSE loss for one batch.

    Args:
        nbatch: dict of tensors from the DataLoader. Reads ``'img_obs'``
            (B, T, obs_dim), of which only the first ``obs_horizon`` frames are
            used, and ``'act_pred'`` (B, pred_horizon, action_dim), the target
            action chunk.

    Returns:
        Scalar tensor: MSE between the network's noise prediction and the noise
        that was added to the target actions.
    """
    nobs = nbatch['img_obs'][:, :obs_horizon].to(device)
    a_gt = nbatch['act_pred'].to(device)
    # Take the batch size from the targets. Previously 'act_obs' was copied to the
    # device only to read its size.
    B = a_gt.size(0)

    # Conditioning is the flattened observation window: (B, obs_horizon * obs_dim).
    # To also condition on past actions, concatenate nbatch['act_obs'][:, :obs_horizon]
    # along the last dim before flattening (and widen global_cond_dim to match).
    obs_cond = nobs.flatten(start_dim=1)

    # Draw noise and one random diffusion step per example, then noise the targets.
    noise = torch.randn_like(a_gt)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                              (B,), device=device).long()
    noisy_a = noise_scheduler.add_noise(a_gt, noise, timesteps)
    noise_pred = noise_prediction_net(noisy_a, timesteps, global_cond=obs_cond)
    return nn.functional.mse_loss(noise_pred, noise)



num_epochs = 200

# Exponential moving average of the weights; after training, the averaged weights
# are copied into the network (ema.copy_to at the end of this cell).
all_params = list(noise_prediction_net.parameters())
ema = EMAModel(parameters=all_params, power=0.75)

# optimizer
optimizer = torch.optim.AdamW(
    params=all_params,
    lr=1e-4,
    weight_decay=1e-6
)

# Cosine LR schedule with warm-up, stepped once per optimisation step.
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(train_loader) * num_epochs
)

# Both losses below are averaged over the number of batches.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("train_loader and val_loader must each yield at least one batch")

# train loop: one train/val entry per epoch
train_hist, val_hist = [], []
for epoch_idx in range(num_epochs):
    # The validation pass below switches the network to eval mode. Re-enable training
    # mode every epoch; previously eval mode persisted into all later epochs.
    noise_prediction_net.train()
    epoch_loss_sum = 0.0

    for nbatch in train_loader:
        # Same loss as validation (shared helper instead of a duplicated copy).
        loss = forward_loss(nbatch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        ema.step(all_params)

        epoch_loss_sum += loss.item()

    # Epoch-level average; previously appended once per batch inside the loop.
    avg_train = epoch_loss_sum / len(train_loader)
    train_hist.append(avg_train)

    # validation (raw, non-EMA weights)
    noise_prediction_net.eval()
    val_sum = 0.0
    with torch.no_grad():
        # Distinct loop name so the inspection `batch` from the data cell is not clobbered.
        for val_batch in val_loader:
            val_sum += forward_loss(val_batch).item()
    avg_val = val_sum / len(val_loader)
    val_hist.append(avg_val)

    print(f"Epoch {epoch_idx+1:03d}/{num_epochs} | "
          f"train {avg_train:.6f} | val {avg_val:.6f}")

# copy EMA weights for inference
ema.copy_to(all_params)

Epoch 001/200 | train 0.021189 | val 0.027548
Epoch 002/200 | train 0.025219 | val 0.031022
Epoch 003/200 | train 0.027860 | val 0.032817
Epoch 004/200 | train 0.028724 | val 0.032325
Epoch 005/200 | train 0.028734 | val 0.033084
Epoch 006/200 | train 0.029250 | val 0.030723
Epoch 007/200 | train 0.028803 | val 0.031742
Epoch 008/200 | train 0.029191 | val 0.031183
Epoch 009/200 | train 0.028572 | val 0.033228
Epoch 010/200 | train 0.029057 | val 0.031237
Epoch 011/200 | train 0.029011 | val 0.031274
Epoch 012/200 | train 0.028611 | val 0.032112
Epoch 013/200 | train 0.028974 | val 0.032849
Epoch 014/200 | train 0.028932 | val 0.031709
Epoch 015/200 | train 0.029081 | val 0.032740
Epoch 016/200 | train 0.029192 | val 0.032069
Epoch 017/200 | train 0.028338 | val 0.031463
Epoch 018/200 | train 0.029103 | val 0.030677
Epoch 019/200 | train 0.028265 | val 0.031312
Epoch 020/200 | train 0.028440 | val 0.031985
Epoch 021/200 | train 0.028636 | val 0.031185
Epoch 022/200 | train 0.027919 | v

In [None]:
from pathlib import Path


def save_final_models(noise_pred_net, out_dir, filename="model_final.pth"):
    """Save the denoiser's weights to ``out_dir / filename``.

    The checkpoint is a dict with a single key ``"noise_pred_net"`` holding
    ``noise_pred_net.state_dict()``. ``out_dir`` and any missing parents are created.

    Args:
        noise_pred_net: module whose ``state_dict()`` is saved.
        out_dir: destination directory (str or Path).
        filename: checkpoint file name; defaults to ``"model_final.pth"``.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = out_dir / filename

    torch.save({"noise_pred_net": noise_pred_net.state_dict()}, ckpt_path)
    # Report the file actually written (previously printed a non-existent 'models.pth').
    print(f"Saved to {ckpt_path}")


# Persist the denoiser's current weights. These are the EMA weights if the training
# cell reached ema.copy_to; otherwise they are the raw weights.
# NOTE(review): "tool_hand" in the path may be a typo for "tool_hang".
save_final_models(
    noise_pred_net=noise_prediction_net,
    out_dir="../models/robot_transformer_v7_91_tool_hand_state",
)

In [None]:
import robomimic
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.utils.vis_utils import depth_to_rgb
from robomimic.envs.env_base import EnvBase, EnvType

from tqdm import tqdm

import os

# Recreate the simulation environment described by the dataset's metadata.
env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=data_path)

# Observation modalities registered with robomimic before the env is created.
# The rollout conditions on env.get_state(), not on these observation keys.
dummy_spec = dict(
    obs=dict(
        low_dim=["robot0_eef_pos"],
        rgb=["agentview_image"]
        # rgb=["robot0_eye_in_hand_image"]
    ),
)

ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)

env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=True, render_offscreen=True, use_image_obs=True)
# No reset here: the rollout cell calls env.reset() itself. The previous extra reset
# was redundant, and its result (`a`) was unused.

from collections import deque

# Rolling windows holding the last `obs_horizon` simulator states and executed actions.
obs_deque  = deque(maxlen=obs_horizon)
act_deque  = deque(maxlen=obs_horizon)
rewards    = []
imgs       = []
step_idx   = 0

max_steps = 1000
# Number of predicted actions executed open-loop before re-planning.
action_horizon  = 2

# The executed block is pred_actions[obs_horizon : obs_horizon + action_horizon].
# If it ran past pred_horizon the block could be empty, and the while loop below
# would never advance step_idx.
assert obs_horizon + action_horizon <= pred_horizon, (
    "obs_horizon + action_horizon must not exceed pred_horizon")

# Sample in eval mode regardless of the mode earlier cells left the network in.
noise_prediction_net.eval()

# ─── 6. Main rollout ──────────────────────────────────────────────────────────
obs = env.reset()
state_vec = env.get_state()["states"]
# Pad the history with copies of the initial state (and zero actions) so the first
# conditioning window is full.
for _ in range(obs_deque.maxlen):
    obs_deque.append(state_vec)
    act_deque.append(np.zeros(action_dim, dtype=np.float32))

done = False
with tqdm(total=max_steps) as pbar:
    while not done and step_idx < max_steps:
        # 6.1 state history, oldest first: (obs_horizon, state_dim).
        # NOTE(review): training conditioned on the dataset's 'img_obs' (58 values per
        # frame in the printed batch). Confirm that env.get_state()["states"] is the same
        # vector; otherwise the conditioning is out of distribution.
        state_np = np.stack(obs_deque)

        with torch.no_grad():
            # 6.2 conditioning vector (1, obs_horizon * state_dim), flattened the same
            # way as obs_cond in training. act_deque is tracked but not used here.
            obs_cond = torch.as_tensor(
                state_np, dtype=torch.float32, device=device).reshape(1, -1)

            # 6.3 reverse diffusion from pure noise to an action chunk
            pred_actions = torch.randn((1, pred_horizon, action_dim), device=device)
            noise_scheduler.set_timesteps(num_diffusion_iters)
            for t in noise_scheduler.timesteps:
                noise_pred = noise_prediction_net(pred_actions, t, global_cond=obs_cond)
                pred_actions = noise_scheduler.step(noise_pred, t, pred_actions).prev_sample

        pred_actions = pred_actions.cpu().numpy()[0]        # (pred_horizon, action_dim)

        # 6.4 execute the next block of actions
        # NOTE(review): the block starts at index obs_horizon; confirm this matches how
        # 'act_pred' is aligned in time relative to 'img_obs' in the dataset.
        start = obs_horizon
        end   = start + action_horizon
        action_block = pred_actions[start:end]          # (action_horizon, action_dim)

        for act in action_block:
            obs, rew, done, info = env.step(act)

            state = env.get_state()["states"]
            frame = env.render(mode="rgb_array", height=512, width=512)

            obs_deque.append(state)
            act_deque.append(act.astype(np.float32))

            rewards.append(rew)
            imgs.append(frame)

            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=float(rew))

            if done or step_idx >= max_steps:
                break

print(f"Rollout finished: {step_idx} steps, total reward {sum(rewards):.2f}")



using obs modality: low_dim with keys: ['robot0_eef_pos']
using obs modality: rgb with keys: ['agentview_image']


INFO: Probing, EGL cannot run on this device


Created environment with name ToolHang
Action size is 7


  4%|▍         | 40/1000 [00:11<04:20,  3.69it/s, reward=0]

In [14]:
import imageio

# Encode the frames rendered during the rollout (`imgs`) as an H.264 MP4.
fps = 24
video_path = "test_state.mp4"

writer = imageio.get_writer(video_path, fps=fps, codec="libx264")
with writer:
    for frame in imgs:
        writer.append_data(frame)

print(f"Saved video to {video_path}")

Saved video to test_state.mp4
