# Load Data

In [None]:
import numpy as np

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from robotics.model_src.dataset import RobosuiteImageActionDataset

# Window sizes: how many past steps the policy observes and how many
# future actions it predicts.
obs_horizon = 4
pred_horizon = 8

# Camera selection forwarded to the dataset. Other values tried so far:
#   "robot0_eye_in_hand", "agentview"
camera_type = None

# Demonstration file to train on.
# Previous choice: "../robomimic/datasets/can/ph/image_griper.hdf5"
data_path = "../robomimic/datasets/tool_hang/ph/image_griper.hdf5"

ds = RobosuiteImageActionDataset(
    data_path,
    camera_type,
    obs_horizon=obs_horizon,
    prediction_horizon=pred_horizon,
)

In [None]:
import torch

# Shuffled, batched loader over the demonstration dataset.
dataloader = torch.utils.data.DataLoader(
    ds,
    batch_size=64,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker processes after each epoch
    persistent_workers=True,
)

# Pull one batch and report the shape of every tensor the training loop uses.
# The printed label is derived from the key so it cannot drift from what is
# actually being indexed.
batch = next(iter(dataloader))
for key in ("img_obs", "act_obs", "act_pred"):
    print(f"batch['{key}'].shape:", batch[key].shape)

In [None]:
from robotics.model_src.diffusion_model import ConditionalUnet1D
from robotics.model_src.visual_encoder import CNNVisualEncoder

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Image encoder (alternative that was tried: CLIPVisualEncoder().to(device)).
visual_encoder = CNNVisualEncoder().to(device)

# Dimensions of the action being denoised and of the observed action
# that is appended to the visual features at every observation step.
action_dim = 7
action_observation_dim = 7

# Per-step conditioning size = encoder output size + observed-action size.
vision_feature_dim = visual_encoder.get_output_shape()
obs_dim = vision_feature_dim + action_observation_dim

# The denoiser is conditioned on all obs_horizon steps flattened together.
noise_prediction_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim * obs_horizon,
).to(device)

In [None]:
# Fetch the first dataset item once (indexing the dataset may hit the disk)
# and build a batch of size 1 limited to the first obs_horizon steps.
# torch.as_tensor with an explicit dtype replaces the legacy torch.Tensor(...)
# constructor and yields float32 regardless of the stored dtype.
first_sample = ds[0]
image = torch.as_tensor(first_sample["img_obs"][None, :obs_horizon], dtype=torch.float32).to(device)
act_obs = torch.as_tensor(first_sample["act_obs"][None, :obs_horizon], dtype=torch.float32).to(device)

In [None]:
from matplotlib import pyplot as plt

# First observation frame of the sample, moved to host memory.
# Presumably channel-first (C, H, W), given the re-ordering below.
im = image[0, 0].cpu().numpy()

# Re-order the axes to (H, W, C) for matplotlib.
im_hwc = im.transpose(1, 2, 0)
plt.imshow(im_hwc)

In [None]:
# Smoke test: push one sample through the encoder and the noise-prediction
# network to verify that all shapes line up. No gradients are needed.
with torch.no_grad():
    # Encode every frame independently, then restore the two leading axes
    # (batch, observation step) of `image`.
    image_features = visual_encoder.encode(image.flatten(start_dim=0, end_dim=1))
    image_features = image_features.reshape(*image.shape[:2], -1)

    # Per-step conditioning: visual features followed by the observed action.
    obs = torch.cat([image_features, act_obs], dim=-1)

    # Pure-noise action sequence, created directly on the target device.
    noised_action = torch.randn((1, pred_horizon, action_dim), device=device)

    # torch.randint's upper bound is exclusive. The schedule used for training
    # and sampling in this notebook has 100 steps (valid indices 0..99), so
    # the bound is 100; the previous bound of 101 could produce index 100.
    timestep_tensor = torch.randint(0, 100, (1,), device=device)

    noise = noise_prediction_net(
        sample=noised_action,
        timestep=timestep_tensor,
        global_cond=obs.flatten(start_dim=1),
    )

    # Crude one-shot subtraction, only useful as a shape check; the rollout
    # cell performs the actual iterative sampling through the scheduler.
    denoised_action = noised_action - noise


In [None]:
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

# Length of the diffusion chain; reused for both training-time noising and
# the sampling loop during the rollout.
num_diffusion_iters = 100

# DDPM scheduler configured for noise ("epsilon") prediction, which is what
# the training loss in this notebook regresses against.
noise_scheduler = DDPMScheduler(
    prediction_type='epsilon',
    clip_sample=True,
    beta_schedule='squaredcos_cap_v2',
    num_train_timesteps=num_diffusion_iters,
)

In [None]:
from torch import nn
from diffusers import EMAModel, get_scheduler

num_epochs = 400

# Parameters of both networks: they are optimised jointly and tracked by a
# single exponential moving average.
all_params = list(visual_encoder.parameters()) + list(noise_prediction_net.parameters())
ema = EMAModel(parameters=all_params, power=0.75)

# optimizer
optimizer = torch.optim.AdamW(
    params=all_params,
    lr=1e-4,
    weight_decay=1e-6
)

# LR scheduler: cosine schedule with warm-up, stepped once per batch
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

# train loop
for epoch_idx in range(num_epochs):
    epoch_loss_sum = 0.0

    for nbatch in dataloader:
        # prepare data (shapes below assume the dataset layout used elsewhere
        # in this notebook — TODO confirm against the dataset class)
        nimage = nbatch['img_obs'][:, :obs_horizon].to(device)      # (B, obs_horizon, C, H, W)
        action_obs = nbatch['act_obs'][:, :obs_horizon].to(device)  # (B, obs_horizon, action_observation_dim)
        action_pred = nbatch['act_pred'].to(device)                 # (B, pred_horizon, action_dim)
        B = action_obs.size(0)

        # forward pass: encode every frame independently, then restore the
        # (B, obs_horizon) leading axes. The module is called directly (not
        # via .forward) so that hooks run and the entry point matches the one
        # used by the rollout cell.
        image_features = visual_encoder(
            nimage.flatten(end_dim=1)      # (B * obs_horizon, C, H, W)
        ).reshape(*nimage.shape[:2], -1)   # (B, obs_horizon, feature_dim)

        # Conditioning layout: per step [image features, observed action],
        # then flattened over the observation steps.
        obs_features = torch.cat([image_features, action_obs], dim=-1)
        obs_cond = obs_features.flatten(start_dim=1)  # (B, obs_horizon * obs_dim)

        # Sample the noise and a random diffusion step per batch element.
        # torch.randint already returns int64, so no extra cast is needed.
        noise = torch.randn_like(action_pred)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (B,), device=device)

        # Forward diffusion, then predict the noise that was added.
        noisy_actions = noise_scheduler.add_noise(action_pred, noise, timesteps)
        noise_pred = noise_prediction_net(noisy_actions, timesteps, global_cond=obs_cond)

        loss = nn.functional.mse_loss(noise_pred, noise)

        # backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        # Update the moving average after every optimiser step.
        ema.step(all_params)

        epoch_loss_sum += loss.item()

    avg_loss = epoch_loss_sum / len(dataloader)
    print(f"Epoch {epoch_idx + 1:3d}/{num_epochs} ─ average loss: {avg_loss:.6f}")

# copy EMA weights into the live models for inference / saving
ema.copy_to(all_params)

In [None]:
from pathlib import Path


def save_final_models(visual_encoder, noise_pred_net, out_dir):
    """Write both networks' state dicts to ``<out_dir>/model_final.pth``.

    Args:
        visual_encoder: Module whose ``state_dict()`` is stored under the
            ``"visual_encoder"`` key.
        noise_pred_net: Module whose ``state_dict()`` is stored under the
            ``"noise_pred_net"`` key.
        out_dir: Target directory (str or Path); created with parents if it
            does not exist yet.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Build the path once so the file that is written and the path that is
    # reported can never disagree.
    ckpt_path = out_dir / "model_final.pth"

    torch.save(
        {
            "visual_encoder": visual_encoder.state_dict(),
            "noise_pred_net": noise_pred_net.state_dict(),
        },
        ckpt_path,
    )
    print(f"Saved to {ckpt_path}")

In [None]:
# Persist the trained networks under this run's model directory.
# NOTE(review): the directory name mentions "can" — confirm it matches the
# dataset this run was actually trained on.
save_final_models(
    visual_encoder=visual_encoder,
    noise_pred_net=noise_prediction_net,
    out_dir="../models/robot_v5_can_cnn_griper_124",
)

In [None]:
def load_final_models(visual_encoder, noise_pred_net, ckpt_path, device=None):
    """Load weights for both networks from a checkpoint file, in place.

    The checkpoint is expected to be a dict with the keys
    ``"visual_encoder"`` and ``"noise_pred_net"``, each holding a state dict.
    After loading, both modules are moved to ``device`` and put in eval mode.

    Args:
        visual_encoder: Module that receives ``state["visual_encoder"]``.
        noise_pred_net: Module that receives ``state["noise_pred_net"]``.
        ckpt_path: Path (str or Path) of the checkpoint file.
        device: Target device. ``None`` (default) selects ``"cuda"`` when
            CUDA is available and ``"cpu"`` otherwise, so the function also
            works on CPU-only machines.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the stored
            keys do not match the module.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    ckpt_path = Path(ckpt_path)
    # NOTE: torch.load is pickle-based; only load checkpoints from trusted sources.
    state = torch.load(ckpt_path, map_location=device)

    visual_encoder.load_state_dict(state["visual_encoder"], strict=True)
    noise_pred_net.load_state_dict(state["noise_pred_net"], strict=True)

    # Inference configuration: target device + eval mode.
    visual_encoder.to(device).eval()
    noise_pred_net.to(device).eval()
    print(f"Loaded weights from {ckpt_path}")


In [None]:
# Restore a trained checkpoint. The notebook's `device` is passed explicitly
# so that loading follows the same CPU/GPU choice as the rest of the notebook.
# NOTE(review): this checkpoint path differs from the directory written by the
# save cell — confirm it is the intended run.
load_final_models(
    visual_encoder,
    noise_prediction_net,
    "../models/v6_success_both_cnn_actions_h1/model_final.pth",
    device=device,
)

# Inference

In [None]:
import robomimic
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.utils.vis_utils import depth_to_rgb
from robomimic.envs.env_base import EnvBase, EnvType

from tqdm import tqdm

import os

env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=data_path)

# Observation spec handed to robomimic's ObsUtils before the env is created.
dummy_spec = dict(
    obs=dict(
        low_dim=["robot0_eef_pos"],
        rgb=["agentview_image"]
        # rgb=["robot0_eye_in_hand_image"]
    ),
)

ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)

env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=True, render_offscreen=True, use_image_obs=True)

# Observation key of the camera that feeds the policy during the rollout.
# `camera_type` can be None (that is how the dataset is configured above), in
# which case `camera_type + "_image"` would raise a TypeError; fall back to the
# camera registered in `dummy_spec`.
# NOTE(review): what the dataset selects for camera_type=None is not visible
# here — confirm the rollout camera matches what the model was trained on.
rollout_camera = camera_type if camera_type is not None else "agentview"
image_key = rollout_camera + "_image"

from collections import deque
obs_deque  = deque(maxlen=obs_horizon)   # last obs_horizon observations
act_deque  = deque(maxlen=obs_horizon)   # last obs_horizon executed actions
rewards    = []
imgs       = []
step_idx   = 0

max_steps = 1000
action_horizon  = 4   # number of predicted actions executed before re-planning

# ─── 6. Main rollout ──────────────────────────────────────────────────────────
# A single reset is enough (an earlier, unused extra reset was dropped).
obs = env.reset()
# wrap obs in same format as env.step
obs = obs if isinstance(obs, dict) else obs[0]
# Pre-fill the history with the initial observation and zero actions.
for i in range(obs_deque.maxlen):
    obs_deque.append(obs)
    act_deque.append(np.zeros(action_dim, dtype=np.float32))

pbar = tqdm(total=max_steps)
done = False

# NOTE(review): confirm that `done` returned by env.step actually becomes True
# on task completion for this env; otherwise every rollout runs max_steps.
while not done and step_idx < max_steps:
    # 6.1 build the image & action history tensors
    img_np = np.array([past_obs[image_key] for past_obs in obs_deque])

    # NOTE(review): assumes env frames are (H, W, C) in the 0..255 range and
    # that training images were scaled the same way — confirm against the
    # dataset class and the env's visual post-processing.
    img_t = torch.from_numpy(img_np.transpose(0, 3, 1, 2)).float().to(device) / 255.0  # (obs_horizon, C, H, W)

    actions_hist = torch.stack(
        [torch.from_numpy(a) for a in act_deque],
        dim=0
    ).to(device)                           # (obs_horizon, action_dim)

    # 6.2 compute visual features + conditioning
    with torch.no_grad():
        img_feat = visual_encoder(img_t).reshape(img_t.shape[0], -1)  # (obs_horizon, feature_dim)

        # Use the same layout as the training loop: per step
        # [image features, action], then flatten over the observation steps.
        # Concatenating all image features followed by all actions (as was
        # done before) yields a differently ordered vector when obs_horizon > 1.
        obs_cond = torch.cat([img_feat, actions_hist], dim=-1).flatten().unsqueeze(0)  # (1, obs_horizon * obs_dim)

        # 6.3 sample a future action sequence via diffusion
        B = 1
        pred_actions = torch.randn((B, pred_horizon, action_dim), device=device)
        noise_scheduler.set_timesteps(num_diffusion_iters)
        for t in noise_scheduler.timesteps:
            noise_pred    = noise_prediction_net(pred_actions, t, global_cond=obs_cond)
            out           = noise_scheduler.step(noise_pred, t, pred_actions)
            pred_actions  = out.prev_sample

    pred_actions = pred_actions.cpu().numpy()[0]        # (pred_horizon, action_dim)

    # 6.4 execute the next block of actions
    # NOTE(review): the first obs_horizon predicted steps are skipped; whether
    # the dataset's `act_pred` overlaps the observation window is not visible
    # here — confirm this offset.
    start = obs_horizon
    end   = start + action_horizon
    action_block = pred_actions[start:end]          # (<= action_horizon, action_dim)

    for act in action_block:
        obs, rew, done, info = env.step(act)
        obs = obs if isinstance(obs, dict) else obs[0]

        # Frame for the output video (independent of the policy's camera input).
        frame = env.render(mode="rgb_array", height=512, width=512)

        obs_deque.append(obs)
        act_deque.append(act.astype(np.float32))

        rewards.append(rew)
        imgs.append(frame)

        step_idx += 1
        pbar.update(1)
        pbar.set_postfix(reward=float(rew))

        if done or step_idx >= max_steps:
            break

pbar.close()

# ─── 7. Wrap up ───────────────────────────────────────────────────────────────
print(f"Rollout finished: {step_idx} steps, total reward {sum(rewards):.2f}")

In [None]:
import imageio

# Output location and playback rate of the rollout video.
fps = 24
video_path = "test_2.mp4"

# Encode the frames collected during the rollout; the writer is closed
# automatically when the `with` block exits.
writer = imageio.get_writer(video_path, fps=fps, codec="libx264")
with writer:
    for frame in imgs:
        writer.append_data(frame)

print(f"Saved video to {video_path}")