# Load Data

In [2]:
import numpy as np
import torch

# Sanity check: report whether PyTorch can see a CUDA device (the cell output shows the result).
torch.cuda.is_available()

True

In [3]:
# IPython autoreload: mode 2 reloads imported modules (e.g. robotics.model_src.*)
# before each cell executes, so edits to project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [4]:
import h5py

# Seeded torch RNG.
# NOTE(review): not referenced by any visible cell; EnsembleLoader seeds its own generator from `seed`.
generator = torch.Generator().manual_seed(33)

# DataLoader settings for the training/validation loaders.
persistent_workers = False
batch_size = 128

# Camera streams loaded per sample. The model cells index img_obs[:, :, 0] as
# agentview and img_obs[:, :, 1] as eye-in-hand — presumably following this
# list's order; confirm in RobosuiteImageActionDataset.
camera_type = ["agentview", "robot0_eye_in_hand"]

# pred_horizon: length of the predicted action sequence.
# obs_horizon: number of past observation steps used as conditioning.
pred_horizon = 8
obs_horizon = 1

# robomimic tool_hang (ph) dataset with agentview + eye-in-hand images.
data_path = "../robomimic/datasets/tool_hang/ph/image_agent_hand.hdf5"

In [None]:
# from robotics.model_src.dataset import preprocess_images_in_place
#
# preprocess_images_in_place(
#     h5_path=data_path,
#     cameras=camera_type,
#     target_dtype=np.float32,
# )

In [5]:
import random
import torch
import gc
import ctypes
from torch.utils.data import DataLoader
from robotics.model_src.dataset import RobosuiteImageActionDataset


class EnsembleLoader:
    """
    A helper that keeps several (Dataset, DataLoader) pairs and exposes a single
    “active” pair. Call rotate() to switch the active pair and release memory
    used by the previous one.

        mgr = DemoDataManager(...)
        ds  = mgr.get_ds()           # active dataset
        ld  = mgr.get_loader()       # active dataloader
        mgr.rotate()                 # switch to next pair
        idx = mgr.get_active()       # index of the active pair
    """

    @staticmethod
    def _build_ds(data_path, camera, obs_h, pred_h, demo_subset):
        return RobosuiteImageActionDataset(
            data_path,
            camera,
            obs_horizon=obs_h,
            pred_horizon=pred_h,
            demos=demo_subset,
        )

    @staticmethod
    def _build_loader(ds, batch_size, num_workers, shuffle, gen, persistent):
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            generator=gen if shuffle else None,
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=persistent,
        )

    def __init__(
        self,
        data_path: str,
        chunk_size: int = 30,
        validation_size: int = 20,
        batch_size: int = 256,
        obs_horizon: int = 1,
        pred_horizon: int = 8,
        camera: list[str] = ["agentview"],
        num_workers: int = 4,
        persistent_workers: bool = False,
        seed: int = 33,
    ):
        f = h5py.File(data_path, "r")

        data = f["data"]

        demos = list(data.keys())

        g = torch.Generator().manual_seed(seed)

        # validation set
        self.val_ds = self._build_ds(
            data_path, camera, obs_horizon, pred_horizon, demos[-validation_size:]
        )
        self.val_loader = self._build_loader(
            self.val_ds, batch_size, num_workers, False, g, True
        )

        # training sets divided into chunks
        train_demos = demos[:-validation_size]
        self.train_pairs = []
        for i in range(0, len(train_demos), chunk_size):
            subset = train_demos[i : i + chunk_size]
            ds = self._build_ds(data_path, camera, obs_horizon, pred_horizon, subset)
            ld = self._build_loader(ds, batch_size, num_workers, True, g, persistent_workers)
            self.train_pairs.append((ds, ld))
            ds.drop_data()  # keep RAM usage low

        self.active = 0
        self.train_pairs[0][0].load_data()

    def rotate(self, use_random = False):
        prev = self.active

        # stop workers and drop current data
        self._stop_workers(self.train_pairs[prev][1])
        self.train_pairs[prev][0].drop_data()

        # move to next index
        if use_random:
            self.active = random.randint(0, len(self.train_pairs) - 1)
        else:
            self.active = (self.active + 1) % len(self.train_pairs)
        print(f"rotating from {prev} to {self.active} dataset")

        # load data for the new active dataset
        self.train_pairs[self.active][0].load_data()

        self._trim()

    def get_ds(self, idx: int | None = None):
        if idx is None:
            idx = self.active
        return self.train_pairs[idx][0]

    def get_loader(self, idx: int | None = None):
        if idx is None:
            idx = self.active
        return self.train_pairs[idx][1]

    def get_val_loader(self, idx: int | None = None):
        return self.val_loader

    def get_active_loader(self) -> int:
        return self.get_loader(self.active)

    @staticmethod
    def _stop_workers(loader):
        it = getattr(loader, "_iterator", None)
        if it is not None:
            it._shutdown_workers()
            loader._iterator = None

    @staticmethod
    def _trim():
        gc.collect()
        try:
            ctypes.CDLL("libc.so.6").malloc_trim(0)
        except OSError:
            pass


In [6]:
# Build the chunked train/validation loaders from the notebook-level settings
# (previously batch size and persistence were hard-coded and the horizons were
# left at the class defaults, so editing the config cell had no effect).
dataset = EnsembleLoader(
    data_path,
    chunk_size=6,
    validation_size=10,
    batch_size=batch_size,
    obs_horizon=obs_horizon,
    pred_horizon=pred_horizon,
    camera=camera_type,
    persistent_workers=persistent_workers,
)

10it [00:07,  1.26it/s]
6it [00:07,  1.21s/it]
6it [00:04,  1.32it/s]
6it [00:04,  1.22it/s]
6it [00:04,  1.27it/s]
6it [00:04,  1.37it/s]
6it [00:05,  1.20it/s]
6it [00:04,  1.34it/s]
6it [00:05,  1.03it/s]
6it [00:04,  1.41it/s]
6it [00:04,  1.24it/s]
6it [00:06,  1.00s/it]
6it [00:05,  1.13it/s]
6it [00:04,  1.22it/s]
6it [00:05,  1.09it/s]
6it [00:04,  1.29it/s]
6it [00:05,  1.12it/s]
6it [00:05,  1.08it/s]
6it [00:04,  1.38it/s]
6it [00:05,  1.03it/s]
6it [00:06,  1.09s/it]
6it [00:05,  1.12it/s]
6it [00:05,  1.07it/s]
6it [00:06,  1.01s/it]
6it [00:05,  1.02it/s]
6it [00:06,  1.02s/it]
6it [00:05,  1.16it/s]
6it [00:06,  1.03s/it]
6it [00:05,  1.11it/s]
6it [00:04,  1.24it/s]
6it [00:04,  1.23it/s]
6it [00:04,  1.31it/s]
4it [00:03,  1.18it/s]
6it [00:06,  1.12s/it]


In [None]:
from robotics.model_src.diffusion_model import ConditionalUnet1D, ConditionalUnet1DTransformer
from robotics.model_src.visual_encoder import CNNVisualEncoder

# Train on the first GPU when available, otherwise on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# visual_encoder = CLIPVisualEncoder().to(device)

# One independent CNN encoder per camera stream.
visual_agentview_encoder = CNNVisualEncoder().to(device)

visual_eye_in_hand_encoder = CNNVisualEncoder().to(device)

# Combined image-feature size of both encoders.
# NOTE(review): assumes get_output_shape() returns an int feature dimension.
vision_feature_dim = visual_agentview_encoder.get_output_shape() + visual_eye_in_hand_encoder.get_output_shape()

# Size of the action-history vector concatenated to the image features.
action_observation_dim = 7

# Per-step conditioning size: agentview features + eye-in-hand features + action history.
obs_dim = vision_feature_dim + action_observation_dim

# Dimensionality of each predicted action.
action_dim = 7

# Noise-prediction network, conditioned on the flattened obs_horizon-step history.
noise_prediction_net = ConditionalUnet1DTransformer(
    input_dim=action_dim,
    global_cond_dim=obs_dim * obs_horizon,
).to(device)

In [None]:
# image = torch.Tensor(dataset.val_ds[0]["img_obs"][None, :obs_horizon, :, :, :]).to(device)
#
# image_agent = image[0, 0, 0]
#
# image_hand = image[0, 0, 1]

In [None]:
# from matplotlib import pyplot as plt
#
# im = image_agent.cpu().numpy()
#
# plt.imshow(im.transpose((1, 2, 0)))
#
# plt.show()
#
#
# im = image_hand.cpu().numpy()
#
# plt.imshow(im.transpose((1, 2, 0)))
#
# plt.show()

In [None]:
# with torch.no_grad():
#
#     image_agent = image[:, : , 0]
#     image_hand = image[:, : , 1]
#
#     image_features_agent = visual_agentview_encoder.encode(image_agent.flatten(start_dim=0, end_dim=1))
#
#     image_features_hand = visual_eye_in_hand_encoder.encode(image_hand.flatten(start_dim=0, end_dim=1))
#
#     image_features = torch.cat([image_features_agent, image_features_hand], dim=-1)
#
#     obs = image_features
#
#     noised_action = torch.randn((1, pred_horizon, action_dim)).to(device)
#
#     timestep_tensor = torch.randint(0, 101, (1,), device=device)
#
#     noise = noise_prediction_net(
#         sample=noised_action,
#         timestep=timestep_tensor,
#         global_cond=obs.flatten(start_dim=1)
#     )
#
#     denoised_action = noised_action - noise

In [None]:
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

# Number of diffusion steps, shared by training (add_noise) and sampling (step).
num_diffusion_iters = 100

# DDPM schedule: the network predicts the added noise (epsilon) and sampled
# actions are clipped back into range at every denoising step.
noise_scheduler = DDPMScheduler(
    prediction_type='epsilon',
    clip_sample=True,
    beta_schedule='squaredcos_cap_v2',
    num_train_timesteps=num_diffusion_iters,
)

In [None]:
from pathlib import Path


def save_final_models(visual_encoder_agent, visual_encoder_hand, noise_pred_net, out_dir):
    """
    Save the final weights of both visual encoders and the noise-prediction
    network into ``<out_dir>/model_final.pth``.

    Args:
        visual_encoder_agent: encoder for the agentview camera.
        visual_encoder_hand: encoder for the eye-in-hand camera.
        noise_pred_net: diffusion noise-prediction network.
        out_dir: output directory; created (with parents) if missing.

    Returns:
        Path of the written checkpoint file.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / "model_final.pth"

    torch.save(
        {
            "visual_encoder_agent": visual_encoder_agent.state_dict(),
            "visual_encoder_hand": visual_encoder_hand.state_dict(),
            "noise_pred_net": noise_pred_net.state_dict(),
        },
        out_path,
    )
    # Report the file actually written (the message used to name "models.pth").
    print(f"Saved to {out_path}")
    return out_path

def load_final_models(visual_encoder, noise_pred_net, ckpt_path, device="cuda", visual_encoder_hand=None):
    """
    Load final model weights, move the modules to ``device`` and set eval mode.

    Two checkpoint layouts are accepted:

    * single-camera: keys ``visual_encoder`` and ``noise_pred_net``;
    * two-camera (as written by ``save_final_models``): keys
      ``visual_encoder_agent``, ``visual_encoder_hand`` and ``noise_pred_net``.
      ``visual_encoder`` receives the agentview weights and
      ``visual_encoder_hand`` the eye-in-hand weights.

    All ``load_state_dict`` calls use ``strict=True``.

    Raises:
        ValueError: the checkpoint layout and ``visual_encoder_hand`` disagree
            (two-camera checkpoint without a hand encoder, or a hand encoder
            given for a single-camera checkpoint).
        KeyError: the checkpoint lacks the expected keys.
    """
    ckpt_path = Path(ckpt_path)
    state = torch.load(ckpt_path, map_location=device)

    two_camera = "visual_encoder_agent" in state
    # Validate before loading anything so a mismatch never leaves modules half-loaded.
    if two_camera and visual_encoder_hand is None:
        raise ValueError(
            f"{ckpt_path} holds two-camera weights; pass visual_encoder_hand to load them"
        )
    if not two_camera and visual_encoder_hand is not None:
        raise ValueError(f"{ckpt_path} holds no 'visual_encoder_hand' weights")

    if two_camera:
        visual_encoder.load_state_dict(state["visual_encoder_agent"], strict=True)
        visual_encoder_hand.load_state_dict(state["visual_encoder_hand"], strict=True)
        modules = (visual_encoder, visual_encoder_hand, noise_pred_net)
    else:
        visual_encoder.load_state_dict(state["visual_encoder"], strict=True)
        modules = (visual_encoder, noise_pred_net)
    noise_pred_net.load_state_dict(state["noise_pred_net"], strict=True)

    for module in modules:
        module.to(device).eval()
    print(f"Loaded weights from {ckpt_path}")

def save_checkpoint(
    epoch,
    loss,
    visual_encoder,
    noise_pred_net,
    ema,
    optimizer,
    scheduler,
    out_dir="checkpoints",
    visual_encoder_hand=None,
):
    """
    Save a resumable training checkpoint.

    The file is named ``checkpoint_epoch{epoch:03d}_loss{loss:.4f}.pth`` inside
    ``out_dir`` (created if missing).

    Args:
        epoch: epoch number stored in the checkpoint and file name.
        loss: loss value stored in the checkpoint and file name.
        visual_encoder: (agentview) visual encoder, saved as ``visual_encoder``.
        noise_pred_net: noise-prediction network.
        ema: EMA helper; its ``state_dict()`` is saved.
        optimizer: optimizer; its ``state_dict()`` is saved.
        scheduler: LR scheduler; its ``state_dict()`` is saved.
        out_dir: output directory.
        visual_encoder_hand: optional second (eye-in-hand) encoder, saved as
            ``visual_encoder_hand``; without it a two-camera model cannot be
            fully restored from the checkpoint.

    Returns:
        Path of the written checkpoint file.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ckpt_name = f"checkpoint_epoch{epoch:03d}_loss{loss:.4f}.pth"
    ckpt_path = out_dir / ckpt_name

    state = {
        "epoch": epoch,
        "loss": loss,
        "visual_encoder": visual_encoder.state_dict(),
        "noise_pred_net": noise_pred_net.state_dict(),
        "ema": ema.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }
    if visual_encoder_hand is not None:
        state["visual_encoder_hand"] = visual_encoder_hand.state_dict()

    torch.save(state, ckpt_path)
    print(f"Checkpoint saved to {ckpt_path}")
    return ckpt_path

def load_checkpoint(
    ckpt_path,
    visual_encoder,
    noise_pred_net,
    ema,
    optimizer=None,
    scheduler=None,
    map_location="cpu",
    visual_encoder_hand=None,
):
    """
    Restore a checkpoint written by ``save_checkpoint``.

    Args:
        ckpt_path: checkpoint file.
        visual_encoder: receives the ``visual_encoder`` weights.
        noise_pred_net: receives the ``noise_pred_net`` weights.
        ema: receives the ``ema`` state.
        optimizer: if given and present in the checkpoint, its state is restored.
        scheduler: if given and present in the checkpoint, its state is restored.
        map_location: passed to ``torch.load``.
        visual_encoder_hand: optional second encoder; receives the
            ``visual_encoder_hand`` weights.

    Returns:
        ``(epoch, loss)``; ``loss`` is None when the checkpoint has none.

    Raises:
        KeyError: ``visual_encoder_hand`` was given but the checkpoint holds no
            ``visual_encoder_hand`` weights (checked before anything is loaded).
    """
    ckpt = torch.load(ckpt_path, map_location=map_location)
    if visual_encoder_hand is not None and "visual_encoder_hand" not in ckpt:
        raise KeyError(f"{ckpt_path} has no 'visual_encoder_hand' weights")

    visual_encoder.load_state_dict(ckpt["visual_encoder"])
    if visual_encoder_hand is not None:
        visual_encoder_hand.load_state_dict(ckpt["visual_encoder_hand"])
    noise_pred_net.load_state_dict(ckpt["noise_pred_net"])
    ema.load_state_dict(ckpt["ema"])
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and "scheduler" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler"])
    print(f"Checkpoint loaded from {ckpt_path}")
    return ckpt["epoch"], ckpt.get("loss", None)


In [None]:
from torch import nn
from diffusers import EMAModel, get_scheduler

def forward_loss(nbatch):
    """
    Return the epsilon-prediction MSE loss of the diffusion policy on one batch.

    Uses notebook globals: ``obs_horizon``, ``device``, both visual encoders,
    ``noise_scheduler`` and ``noise_prediction_net``. It is used for the
    validation loss, so it follows the same steps and float32 casts as the
    per-batch computation in the training loop.

    Args:
        nbatch: batch dict with ``img_obs``, ``act_obs`` and ``act_pred``.
            ``img_obs`` is indexed as [batch, step, camera, ...] with camera 0
            treated as agentview and 1 as eye-in-hand (assumed to follow the
            dataset's camera order — confirm).

    Returns:
        Scalar loss tensor.
    """
    nobs = nbatch['img_obs'][:, :obs_horizon].to(device)
    a_obs = nbatch['act_obs'][:, :obs_horizon].to(device)
    # Cast the target like the training loop does; otherwise a non-float32
    # target yields non-float32 noise / noisy actions for the float32 network.
    a_gt = nbatch['act_pred'].to(device).to(torch.float32)
    B = a_obs.size(0)

    image_agent = nobs[:, :, 0]
    image_hand = nobs[:, :, 1]

    # Encode every (batch, step) frame in one pass, then restore (B, steps, F).
    image_features_agent = visual_agentview_encoder.encode(
        image_agent.flatten(start_dim=0, end_dim=1)
    ).reshape(*nobs.shape[:2], -1)
    image_features_hand = visual_eye_in_hand_encoder.encode(
        image_hand.flatten(start_dim=0, end_dim=1)
    ).reshape(*nobs.shape[:2], -1)

    obs = torch.cat(
        [image_features_agent, image_features_hand, a_obs], dim=2
    ).to(torch.float32)
    obs_cond = obs.flatten(start_dim=1)  # (B, steps * obs_dim)

    # Noise the ground-truth actions at random timesteps and regress the noise.
    noise = torch.randn_like(a_gt)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                              (B,), device=device).long()
    noisy_a = noise_scheduler.add_noise(a_gt, noise, timesteps).to(torch.float32)
    noise_pred = noise_prediction_net(noisy_a, timesteps, global_cond=obs_cond)
    return nn.functional.mse_loss(noise_pred, noise)

In [None]:
num_epochs = 3000

# EMA over every trained parameter (noise net + both visual encoders).
all_params = (
    list(noise_prediction_net.parameters())
    + list(visual_agentview_encoder.parameters())
    + list(visual_eye_in_hand_encoder.parameters())
)
ema = EMAModel(parameters=all_params, power=0.75)

# optimizer
optimizer = torch.optim.AdamW(
    params=all_params,
    lr=1e-4,
    weight_decay=1e-6
)

# LR scheduler. lr_scheduler.step() runs once per batch, so the schedule length
# should be (batches per epoch) * num_epochs. The previous hard-coded 6000
# batches/epoch was unrelated to the loader length, so the cosine decay did not
# follow the real number of optimizer steps. Chunks can differ in size (the
# last chunk may be smaller), so the active chunk's length is an estimate.
steps_per_epoch = len(dataset.get_active_loader())
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=steps_per_epoch * num_epochs,
)


In [None]:
# load_checkpoint("../model_src/checkpoints/checkpoint_epoch050_loss0.0398.pth", visual_encoder, noise_prediction_net, ema, optimizer, lr_scheduler, map_location="cuda")

In [None]:
rotate_every = 5  # epochs spent on one training chunk before rotating

from torch import nn
from diffusers import EMAModel, get_scheduler

# Every module whose parameters are optimized. They are switched between
# train() and eval() together; previously only the noise net was switched, so
# the visual encoders stayed in train mode during validation.
trainable_modules = (
    noise_prediction_net,
    visual_agentview_encoder,
    visual_eye_in_hand_encoder,
)

train_hist, val_hist = [], []
for epoch_idx in range(num_epochs):

    loader = dataset.get_active_loader()  # loader of the currently active chunk

    epoch_loss_sum = 0.0
    for module in trainable_modules:
        module.train()
    for nbatch in loader:
        nobs = nbatch['img_obs'][:, :obs_horizon].to(device)
        action_obs = nbatch['act_obs'][:, :obs_horizon].to(device)  # (B, obs_horizon, action_observation_dim)
        action_pred = nbatch['act_pred'].to(device).to(torch.float32)
        B = nobs.size(0)

        # Dim 2 of img_obs is treated as the camera axis: 0 = agentview, 1 = eye-in-hand.
        image_agent = nobs[:, :, 0]
        image_hand = nobs[:, :, 1]

        # Encode every (batch, step) frame in one pass, then restore (B, obs_horizon, F).
        image_features_agent = visual_agentview_encoder.encode(image_agent.flatten(start_dim=0, end_dim=1)).reshape(*nobs.shape[:2], -1)

        image_features_hand = visual_eye_in_hand_encoder.encode(image_hand.flatten(start_dim=0, end_dim=1)).reshape(*nobs.shape[:2], -1)

        image_features = torch.cat([image_features_agent, image_features_hand, action_obs], dim=2).to(torch.float32)

        del image_features_agent, image_features_hand

        obs_cond = image_features.flatten(1)  # (B, obs_horizon * obs_dim)

        # Epsilon-prediction objective: noise the ground-truth actions at random
        # timesteps and regress the added noise.
        noise = torch.randn_like(action_pred, dtype=torch.float32)  # same device as action_pred
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (B,), device=device).long()
        noisy_act = noise_scheduler.add_noise(action_pred, noise, timesteps).to(torch.float32)
        noise_pred = noise_prediction_net(noisy_act, timesteps, global_cond=obs_cond)

        loss = nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        ema.step(all_params)

        epoch_loss_sum += loss.item()

    avg_train = epoch_loss_sum / len(loader)
    train_hist.append(avg_train)

    # Validation with every module in eval mode (no dropout / running-stat updates).
    for module in trainable_modules:
        module.eval()
    val_loader = dataset.get_val_loader()
    with torch.no_grad():
        val_sum = sum(forward_loss(b).item() for b in val_loader)
    avg_val = val_sum / len(val_loader)
    val_hist.append(avg_val)

    print(f"Epoch {epoch_idx+1:03d}/{num_epochs} | "
          f"train {avg_train:.6f} | val {avg_val:.6f}")

    if (epoch_idx + 1) % 100 == 0:
        # NOTE(review): only the agentview encoder is passed here, so the
        # eye-in-hand encoder's weights are not part of these checkpoints.
        save_checkpoint(
            epoch         = epoch_idx + 1,
            loss          = avg_val,
            visual_encoder= visual_agentview_encoder,
            noise_pred_net= noise_prediction_net,
            ema           = ema,
            optimizer     = optimizer,
            scheduler     = lr_scheduler,
            out_dir       = "checkpoints",
        )

    # rotate only after the epoch, when loader is finished
    if (epoch_idx + 1) % rotate_every == 0:
        del loader                          # drop reference
        dataset.rotate(use_random=True)     # switch to a randomly chosen chunk

# Replace the live weights with their EMA averages for saving / inference.
ema.copy_to(all_params)

# ckpt_path = "./checkpoints/checkpoint_epoch190_loss0.0316.pth"
#
# ckpt = torch.load(ckpt_path, map_location="cuda")
# visual_encoder.load_state_dict(ckpt["visual_encoder"])
# noise_prediction_net.load_state_dict(ckpt["noise_pred_net"])

In [None]:
# Persist both encoders and the noise net. If the training cell ran to the end,
# these are the EMA weights (it finishes with ema.copy_to(all_params)).
save_final_models(
    visual_encoder_agent=visual_agentview_encoder,
    visual_encoder_hand=visual_eye_in_hand_encoder,
    noise_pred_net=noise_prediction_net,
    out_dir="../models/robot_v8_tool_hang_agent_img_only_two_cameras",
)

In [None]:
# `visual_encoder` is not defined in this notebook (its CLIP definition is
# commented out); this agentview checkpoint's encoder weights go into the
# agentview encoder.
# NOTE(review): the checkpoint must match the current architecture (CNN encoder
# + ConditionalUnet1DTransformer with the current global_cond_dim); weights are
# loaded with strict=True and will fail otherwise.
load_final_models(visual_agentview_encoder, noise_prediction_net, "../models/robot_v8_tool_hang_agent_224_e40/model_final.pth")

In [None]:
val_sum = 0
val_loader = dataset.get_val_loader()

# Evaluate with every module in eval mode (the encoders were never switched).
for module in (noise_prediction_net, visual_agentview_encoder, visual_eye_in_hand_encoder):
    module.eval()

with torch.no_grad():
    for batch in val_loader:
        val_sum += forward_loss(batch).item()

# mean per-batch validation loss (displayed as the cell output)
val_sum / len(val_loader)

# Inference

In [None]:
import robomimic
import robomimic.utils.obs_utils as ObsUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.file_utils as FileUtils
from robomimic.utils.vis_utils import depth_to_rgb
from robomimic.envs.env_base import EnvBase, EnvType

from tqdm import tqdm

import os

env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=data_path)
env_meta["env_kwargs"]["reward_shaping"] = True
env_meta["env_kwargs"]["reward_scale"]   = 1.0

# Register BOTH cameras as rgb observations. robomimic only post-processes
# registered rgb keys (vertical flip to match the dataset, HWC uint8 -> CHW
# float in [0, 1]); an unregistered camera is returned as the raw simulator
# image. The eye-in-hand stream previously got that raw image and an ad-hoc
# permute + /256, unlike the agentview stream.
dummy_spec = dict(
    obs=dict(
        low_dim=["robot0_eef_pos"],
        rgb=["agentview_image", "robot0_eye_in_hand_image"],
    ),
)

ObsUtils.initialize_obs_utils_with_obs_specs(obs_modality_specs=dummy_spec)

env = EnvUtils.create_env_from_metadata(env_meta=env_meta, render=True, render_offscreen=True, use_image_obs=True)

from collections import deque
obs_deque  = deque(maxlen=obs_horizon)   # last obs_horizon observations (oldest first)
act_deque  = deque(maxlen=obs_horizon)   # last obs_horizon executed actions, each (1, action_dim)
rewards    = []
imgs       = []
step_idx   = 0

max_steps = 500
action_horizon  = 8

# Inference: every module in eval mode.
for module in (noise_prediction_net, visual_agentview_encoder, visual_eye_in_hand_encoder):
    module.eval()

# The denoising schedule is identical for every control step; set it once.
noise_scheduler.set_timesteps(num_diffusion_iters)

# ─── 6. Main rollout ──────────────────────────────────────────────────────────
obs = env.reset()
# wrap obs in same format as env.step
obs = obs if isinstance(obs, dict) else obs[0]
for _ in range(obs_deque.maxlen):
    obs_deque.append(obs)
    act_deque.append(torch.zeros((1, action_dim)))

pbar = tqdm(total=max_steps)
done = False

while not done and step_idx < max_steps:
    # 6.1 stack the observation history per camera (both already CHW in [0, 1])
    img_np_agent = np.array([o["agentview_image"] for o in obs_deque])
    img_np_hand = np.array([o["robot0_eye_in_hand_image"] for o in obs_deque])
    img_t_agent = torch.from_numpy(img_np_agent).float().to(device)
    img_t_hand = torch.from_numpy(img_np_hand).float().to(device)

    # Read the action history without consuming it: (obs_horizon, action_dim).
    action_hist = torch.cat(list(act_deque), dim=0).to(device)

    # 6.2 compute visual features + conditioning
    with torch.no_grad():
        image_features = torch.cat(
            [
                visual_agentview_encoder.encode(img_t_agent),
                visual_eye_in_hand_encoder.encode(img_t_hand),
                action_hist,
            ],
            dim=-1,
        )
        # Same layout as training: flatten (obs_horizon, obs_dim) into a single
        # conditioning vector for a batch of 1.
        obs_cond = image_features.reshape(1, -1)

        # 6.3 sample a future action sequence via diffusion
        pred_actions = torch.randn((1, pred_horizon, action_dim), device=device)
        for t in noise_scheduler.timesteps:
            noise_pred = noise_prediction_net(pred_actions, t, global_cond=obs_cond)
            pred_actions = noise_scheduler.step(noise_pred, t, pred_actions).prev_sample

    pred_actions = pred_actions.cpu().numpy()[0]        # (pred_horizon, action_dim)

    # 6.4 execute the next block of actions
    # NOTE(review): Diffusion Policy usually starts at obs_horizon - 1 (the action
    # for the current step); starting at obs_horizon skips one predicted action.
    # Confirm against how act_pred is aligned in RobosuiteImageActionDataset.
    start = obs_horizon
    end   = start + action_horizon
    action_block = pred_actions[start:end]          # at most pred_horizon - start actions

    for act in action_block:
        obs, rew, done, info = env.step(act)
        obs = obs if isinstance(obs, dict) else obs[0]

        obs_deque.append(obs)
        act_deque.append(torch.tensor(act, dtype=torch.float32).unsqueeze(0))

        frame = env.render(mode="rgb_array", height=512, width=512)

        rewards.append(rew)
        imgs.append(frame)

        step_idx += 1
        pbar.update(1)
        pbar.set_postfix(reward=float(rew))

        if done or step_idx >= max_steps:
            break

pbar.close()

# ─── 7. Wrap up ───────────────────────────────────────────────────────────────
print(f"Rollout finished: {step_idx} steps, total reward {sum(rewards):.2f}")

In [None]:
import imageio

# Encode the rollout frames collected in `imgs` (512x512 renders) into an MP4.
video_path = "test_two_cams.mp4"
fps = 24

# The context manager finalizes/closes the video file even if a write fails.
with imageio.get_writer(video_path, fps=fps, codec="libx264") as writer:
    for frame in imgs:
        writer.append_data(frame)

print(f"Saved video to {video_path}")

In [None]:
# First (oldest) agentview frame of the last observation history stacked in the
# rollout loop; `img_np` was never defined there (it is `img_np_agent`).
img = img_np_agent[0]

In [None]:
# Display the CHW frame as HWC.
# NOTE(review): `plt` is only imported in a commented-out cell
# (`from matplotlib import pyplot as plt`); run that import first or this raises NameError.
plt.imshow(img.transpose(1,2,0))