# Imports

In [None]:
from robotics.gym_pusht.envs.pusht import PushTImageEnv
from robotics.model_src.dataset import PushTDataset
from robotics.model_src.diffusion_model import ConditionalUnet1D
from robotics.model_src.visual_encoder import CLIPVisualEncoder, CNNVisualEncoder

import numpy as np
import torch
import torch.nn as nn
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

In [None]:
# Automatically reload edited project modules (robotics.*) without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define Env

In [None]:
#@markdown ### **Env Demo**
#@markdown Gym-style env: reset() returns (obs, info) and step() returns a 5-tuple (Gymnasium / gym>=0.26 API)

# Sanity check of the PushT image environment: create it, seed it, reset it,
# take one random action, and print the observation/action shapes.

# 0. create env object
env = PushTImageEnv()

# 1. seed env for a reproducible initial state.
# Seeds 0-200 are used for the demonstration dataset; 1000 is presumably chosen
# outside that range so the demo start is not one of the training starts.
env.seed(1000)

# 2. must reset before use; returns (observation dict, info dict)
obs, info = env.reset()

# 3. sample a random 2D positional action (range [0,512] per the notes printed below)
action = env.action_space.sample()

# 4. step returns (obs, reward, terminated, truncated, info)
obs, reward, terminated, truncated, info = env.step(action)

# Print the shape of each observation/action entry. The dtype/range text after
# each shape is a fixed annotation and is not checked against the data.
with np.printoptions(precision=4, suppress=True, threshold=5):
    print("obs['image'].shape:", obs['image'].shape, "float32, [0,1]")
    print("obs['agent_pos'].shape:", obs['agent_pos'].shape, "float32, [0,512]")
    print("action.shape: ", action.shape, "float32, [0,512]")

# Load Data

In [None]:
import zarr

dataset_version = 1

# Open the recorded demonstrations read-only and read each array fully into memory.
ds = zarr.open("../data/demonstrations_snapshot_3.zarr", mode="r")

prev_actions = ds["data"]["action"][:]  # expected shape (N, 2)
prev_images = ds["data"]["img"][:]  # expected shape (N, 96, 96, 3)
prev_episode_ends = ds["meta"]["episode_ends"][:]  # expected shape (M,)

# Report shape and dtype of every loaded array.
for label, arr in (
    ("actions", prev_actions),
    ("images", prev_images),
    ("episode_ends", prev_episode_ends),
):
    print(f"{label}:", arr.shape, arr.dtype)

In [None]:
pred_horizon = 16  # length of the action sequence the model predicts per sample
obs_horizon = 1  # number of past observation steps used as conditioning
action_horizon = 8  # number of predicted actions executed before re-planning

# NOTE(review): machine-specific absolute path; adjust for your environment.
dataset_path = "/home/may33/Downloads/pusht/pusht_cchi_v7_replay.zarr"

# create dataset from file
dataset = PushTDataset(
    data_path=dataset_path,
    prediction_horizon=pred_horizon,
    obs_horizon=obs_horizon,
)

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker processes after each epoch
    persistent_workers=True,
)

# Inspect one batch: print the shape of every tensor under its actual key.
batch = next(iter(dataloader))
print("batch['img_obs'].shape:", batch['img_obs'].shape)
print("batch['act_obs'].shape:", batch['act_obs'].shape)
print("batch['act_pred'].shape:", batch['act_pred'].shape)

## Dataset space coverage

In [None]:
from matplotlib import pyplot as plt

# Dataset actions. They are presumably normalised to [0, 1] (the env's action
# range is [0, 512]), so scale them back to env coordinates for plotting.
actions = dataset.actions_data_transformed

x = actions[:, 0] * 512
y = actions[:, 1] * 512

# Draw on a fresh figure so the scatter never lands on a previously active figure.
fig, ax = plt.subplots()
ax.scatter(x, y, s=0.2, c="purple")
ax.set_xlabel("x [0, 512]")
ax.set_ylabel("y [0, 512]")
ax.set_aspect("equal")  # both axes share the same env coordinate scale

fig.savefig("pushT_state_coverage.png", dpi=300, bbox_inches="tight")

In [None]:
# Single-sample batch (leading batch dim of 1) built from the first dataset item,
# used by the shape smoke test below. The item is fetched once, since indexing the
# dataset may re-read/transform data.
sample = dataset[0]
image = torch.as_tensor(
    sample["img_obs"][None, :obs_horizon, :, :, :], dtype=torch.float32, device=device
)
act_obs = torch.as_tensor(
    sample["act_obs"][None, :obs_horizon, :], dtype=torch.float32, device=device
)

# Define Models

In [None]:
# Image encoder. Swap in CLIPVisualEncoder() to use CLIP image features instead.
visual_encoder = CNNVisualEncoder().to(device)

# Per-step conditioning vector = image features + 2D agent position.
vision_feature_dim = visual_encoder.get_output_shape()
action_observation_dim = 2
action_dim = 2  # the policy outputs 2D positional actions
obs_dim = vision_feature_dim + action_observation_dim

# Noise-prediction network over action sequences, conditioned on the flattened
# observation history (obs_horizon steps of obs_dim features each).
noise_prediction_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_horizon * obs_dim,
).to(device)

In [None]:
# Smoke test: one untrained forward pass through encoder + noise predictor to
# confirm that the tensor shapes line up before training.
with torch.no_grad():
    # Fold the observation-history dim into the batch for the encoder, then unfold.
    image_features = visual_encoder.encode(image.flatten(start_dim=0, end_dim=1))
    image_features = image_features.reshape(*image.shape[:2], -1)

    # Per-step conditioning = image features concatenated with the agent-position history.
    obs = torch.cat([image_features, act_obs], dim=-1)

    noised_action = torch.randn((1, pred_horizon, action_dim), device=device)

    # Timestep in [0, 100): the DDPM scheduler is configured with
    # num_train_timesteps = 100, i.e. valid indices 0..99.
    timestep_tensor = torch.randint(0, 100, (1,), device=device)

    noise = noise_prediction_net(
        sample=noised_action,
        timestep=timestep_tensor,
        global_cond=obs.flatten(start_dim=1),
    )
    assert noise.shape == noised_action.shape, (noise.shape, noised_action.shape)

    # NOTE: not a real DDPM denoising step (no scheduler coefficients); this only
    # exercises the output of the untrained network.
    denoised_action = noised_action - noise


In [None]:
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

# Length of the diffusion chain, shared by training (add_noise) and sampling (step).
num_diffusion_iters = 100

# Squared-cosine beta schedule; the network predicts the added noise ("epsilon"),
# and the scheduler clips predicted samples (clip_sample=True).
noise_scheduler = DDPMScheduler(
    prediction_type='epsilon',
    beta_schedule='squaredcos_cap_v2',
    clip_sample=True,
    num_train_timesteps=num_diffusion_iters,
)

# Train

In [None]:
num_epochs = 100

# Both networks are optimised jointly, so the visual encoder is trained end-to-end.
all_params = list(visual_encoder.parameters()) + list(noise_prediction_net.parameters())

# Exponential moving average of the weights; copied into the live models after training.
ema = EMAModel(parameters=all_params, power=0.75)

# optimizer
optimizer = torch.optim.AdamW(
    params=all_params,
    lr=1e-4,
    weight_decay=1e-6
)

# Cosine LR schedule with 500 warmup steps, stepped once per batch.
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

# Explicitly enter training mode: load_final_models() or an earlier evaluation
# may have left the modules in eval mode, which changes layers such as
# dropout / batch norm.
visual_encoder.train()
noise_prediction_net.train()

for epoch_idx in range(num_epochs):
    epoch_loss_sum = 0.0

    for nbatch in dataloader:
        # prepare data
        nimage = nbatch['img_obs'][:, :obs_horizon].to(device)  # (B, H, 3, 96, 96)
        action_obs = nbatch['act_obs'][:, :obs_horizon].to(device)  # (B, H, 2)
        action_pred = nbatch['act_pred'].to(device)  # (B, P, 2)
        B = action_obs.size(0)

        # Encode every observation step: fold H into the batch, then unfold -> (B, H, D).
        # Calling the module (not .forward) keeps any registered hooks active.
        image_features = visual_encoder(
            nimage.flatten(end_dim=1)  # (B*H, 3, 96, 96)
        ).reshape(*nimage.shape[:2], -1)

        obs_features = torch.cat([image_features, action_obs], dim=-1)
        obs_cond = obs_features.flatten(start_dim=1)  # (B, H*obs_dim)

        # Forward diffusion: corrupt the target action sequence at a random timestep.
        noise = torch.randn_like(action_pred)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (B,), device=device
        )  # already int64
        noisy_actions = noise_scheduler.add_noise(action_pred, noise, timesteps)

        # Epsilon-prediction objective: regress the noise that was added.
        noise_pred = noise_prediction_net(noisy_actions, timesteps, global_cond=obs_cond)
        loss = nn.functional.mse_loss(noise_pred, noise)

        # backward + parameter / LR / EMA updates
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        ema.step(all_params)

        epoch_loss_sum += loss.item()

    avg_loss = epoch_loss_sum / len(dataloader)
    print(f"Epoch {epoch_idx + 1:3d}/{num_epochs} ─ average loss: {avg_loss:.6f}")

# Copy the EMA weights into the live models and switch to eval mode for inference.
ema.copy_to(all_params)
visual_encoder.eval()
noise_prediction_net.eval()

# Save / Load model
Epoch  19/100 ─ average loss: 0.021065
Epoch  20/100 ─ average loss: 0.019704
Epoch  21/100 ─ average loss: 0.020990
Epoch  22/100 ─ average loss: 0.019516
Epoch  23/100 ─ average loss: 0.019944
Epoch  24/100 ─ average loss: 0.019400

In [None]:
from pathlib import Path


def save_final_models(visual_encoder, noise_pred_net, out_dir):
    """Save both networks' weights to ``<out_dir>/model_final.pth``.

    Args:
        visual_encoder: Module whose ``state_dict()`` is stored under ``"visual_encoder"``.
        noise_pred_net: Module whose ``state_dict()`` is stored under ``"noise_pred_net"``.
        out_dir: Output directory (str or Path); created with parents if missing.

    Returns:
        Path of the checkpoint file that was written.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = out_dir / "model_final.pth"

    torch.save(
        {
            "visual_encoder": visual_encoder.state_dict(),
            "noise_pred_net": noise_pred_net.state_dict(),
        },
        ckpt_path,
    )
    # Report the file that was actually written.
    print(f"Saved to {ckpt_path}")
    return ckpt_path

In [None]:
# Persist the trained weights (EMA weights after the training cell) for later evaluation.
save_final_models(
    visual_encoder=visual_encoder,
    noise_pred_net=noise_prediction_net,
    out_dir="../models/v6_success_both_cnn_actions_h1",
)

In [None]:
def load_final_models(visual_encoder, noise_pred_net, ckpt_path, device=None):
    """Load a checkpoint written by ``save_final_models`` into both modules.

    Both modules are moved to ``device`` and switched to eval mode.

    Args:
        visual_encoder: Module receiving the ``"visual_encoder"`` state dict (strict).
        noise_pred_net: Module receiving the ``"noise_pred_net"`` state dict (strict).
        ckpt_path: Path (str or Path) of the ``.pth`` checkpoint file.
        device: Target device, also used as ``map_location``. ``None`` picks
            ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise.

    Raises:
        KeyError: If a state dict entry is missing from the checkpoint.
        RuntimeError: If the state dicts do not match the modules (strict loading).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    ckpt_path = Path(ckpt_path)
    # The checkpoint holds only state dicts of tensors, so the restricted
    # weights-only unpickler is sufficient and avoids executing arbitrary pickles.
    state = torch.load(ckpt_path, map_location=device, weights_only=True)

    visual_encoder.load_state_dict(state["visual_encoder"], strict=True)
    noise_pred_net.load_state_dict(state["noise_pred_net"], strict=True)

    visual_encoder.to(device).eval()
    noise_pred_net.to(device).eval()
    print(f"Loaded weights from {ckpt_path}")


In [None]:
# Restore the saved weights onto the notebook's device (CPU when CUDA is unavailable).
load_final_models(
    visual_encoder,
    noise_prediction_net,
    "../models/v6_success_both_cnn_actions_h1/model_final.pth",
    device=device,
)

# Inference

In [None]:
import time
import torch
import numpy as np
import collections
from tqdm import tqdm
from robotics.model_src.dataset import normalize_data

# ─────────────── Config ───────────────
n_val = 20  # evaluation episodes per action horizon
max_steps = 200  # hard cap on env steps per episode

pred_horizon = 16
action_dim = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# List of action_horizons to evaluate
action_horizons = [1, 2, 4, 8, 16]


def predict_action_sequence(obs_deque, act_deque):
    """Sample one action sequence with the diffusion policy.

    Conditions on the image / agent-position history held in the two deques. It
    runs the full reverse-diffusion loop of the global ``noise_scheduler`` using
    ``noise_prediction_net`` and ``visual_encoder``. The sample is mapped back to
    env coordinates with ``normalize_data(..., scale=1 / 512)``.

    Returns:
        numpy array of shape (pred_horizon, action_dim).
    """
    images_hist = np.stack([x['image'] for x in obs_deque])
    actions_hist = normalize_data(np.stack(act_deque), scale=512)

    images_hist = torch.from_numpy(images_hist).to(device, dtype=torch.float32)
    actions_hist = torch.from_numpy(actions_hist).to(device, dtype=torch.float32)

    with torch.no_grad():
        image_features = visual_encoder(images_hist)
        obs_features = torch.cat([image_features, actions_hist], dim=-1)
        obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)

        # Start from pure noise and denoise step by step.
        pred_actions = torch.randn((1, pred_horizon, action_dim), device=device)
        noise_scheduler.set_timesteps(num_diffusion_iters)
        for k in noise_scheduler.timesteps:
            noise_pred = noise_prediction_net(
                sample=pred_actions,
                timestep=k,
                global_cond=obs_cond
            )
            pred_actions = noise_scheduler.step(
                model_output=noise_pred,
                timestep=k,
                sample=pred_actions
            ).prev_sample

    return normalize_data(pred_actions.cpu().numpy()[0], scale=1 / 512)


# ─────────────── Evaluation ───────────────
# Inference must not use training-mode behaviour (dropout / batch-norm batch stats).
visual_encoder.eval()
noise_prediction_net.eval()

results = []

for action_horizon in action_horizons:
    print(f"\nEvaluating for action_horizon = {action_horizon}...")
    episodes = []
    episode_times = []

    for episode_i in range(n_val):
        env = PushTImageEnv(render_size_vis=512)
        env.seed(100000 + episode_i)
        obs, info = env.reset()

        obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
        init_action = obs['agent_pos'].copy()
        act_deque = collections.deque([init_action] * obs_horizon, maxlen=obs_horizon)

        done = False
        rewards = []
        step_idx = 0

        episode_start_time = time.perf_counter()

        try:
            with tqdm(total=max_steps, desc=f"Episode {episode_i + 1}/{n_val} | Horizon {action_horizon}") as pbar:
                while not done:
                    action_exec = predict_action_sequence(obs_deque, act_deque)

                    # Execute the first `action_horizon` actions, aligned to the latest observation.
                    start = obs_horizon - 1
                    end = start + action_horizon
                    action_block = action_exec[start:end]

                    for act in action_block:
                        obs, reward, terminated, truncated, info = env.step(act)
                        obs_deque.append(obs)
                        act_deque.append(act)

                        rewards.append(reward)
                        step_idx += 1
                        pbar.update(1)
                        pbar.set_postfix(reward=float(reward))

                        # Stop as soon as the env ends the episode or the step budget is
                        # spent. Continuing the block after termination would step a
                        # finished env and could reset `done` to False.
                        done = terminated or truncated or step_idx >= max_steps
                        if done:
                            break

            episode_duration = time.perf_counter() - episode_start_time
        finally:
            env.close()

        print(f"Score: {max(rewards)} | Episode Time: {episode_duration:.2f} s")

        episodes.append(max(rewards))
        episode_times.append(episode_duration)

    results.append({
        "horizon": action_horizon,
        "avg_score": np.mean(episodes),
        "std_score": np.std(episodes),
        "avg_episode_time": np.mean(episode_times),
        "std_episode_time": np.std(episode_times),
    })

# ─────────────── Summary ───────────────
print("\n=== Evaluation Summary ===")
print(f"{'Horizon':>8} | {'Score μ':>8} | {'Score σ':>8} | {'Time μ (s)':>10} | {'Time σ':>10}")
print("-" * 58)
for res in results:
    print(
        f"{res['horizon']:>8} | {res['avg_score']:>8.2f} | {res['std_score']:>8.2f} | {res['avg_episode_time']:>10.4f} | {res['std_episode_time']:>10.4f}")


In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Hard-coded results, presumably transcribed from the "Evaluation Summary"
# printed by the horizon sweep. Keep these lists in sync and of equal length.
horizons = [1, 2, 4, 8, 16]
x_pos = list(range(len(horizons)))  # evenly spaced categorical positions

# Score data: per-horizon mean and standard deviation
score_mu = [0.78, 0.82, 0.77, 0.79, 0.81]
score_sigma = [0.29, 0.29, 0.28, 0.27, 0.25]

# Time data: mean episode time (s)
time_mu = [34.8187, 21.4516, 11.3340, 5.5704, 3.3637]

box_width = 0.35
ymin_clip, ymax_clip = 0.4, 1.0
# Cap height used when a box's top edge (mu + sigma) lies above the visible range:
# kept just below ymax_clip so the line is not cut off by the axes.
top_line_y = 0.995

fig, ax_score = plt.subplots(figsize=(7, 4))

# One mu ± sigma box per horizon, with a line at the mean and a line at the top.
for idx, (mu, sigma) in enumerate(zip(score_mu, score_sigma)):
    left = idx - box_width / 2
    right = idx + box_width / 2

    rect = Rectangle((left, mu - sigma),
                     box_width, 2 * sigma,
                     fill=False, linewidth=1, edgecolor='black')
    ax_score.add_patch(rect)

    # Mean line
    ax_score.hlines(mu, left, right, linewidth=1, color='black')

    # Top line: the real box top when visible, otherwise a cap at top_line_y
    # marking that the box extends beyond the plotted range.
    ax_score.hlines(min(mu + sigma, top_line_y), left, right,
                    linewidth=1, color='black')

# Configure axes
ax_score.set_ylabel("Score", fontsize=16)
ax_score.set_ylim(ymin_clip, ymax_clip)
ax_score.set_xticks(x_pos)
ax_score.set_xticklabels([str(h) for h in horizons])
ax_score.set_xlabel("Horizon (categorical)", fontsize=16)

# Time axis on the right, sharing the categorical x positions
ax_time = ax_score.twinx()
ax_time.plot(x_pos, time_mu, linestyle='--', marker='o')
ax_time.set_ylabel("Time (s)", fontsize=16)

plt.tight_layout()

plt.savefig("pusht_eval.png", dpi=300)

plt.show()


In [None]:
import time
import torch
import numpy as np
import collections
from tqdm import tqdm
from robotics.model_src.dataset import normalize_data

# ─────────────── Config ───────────────
max_steps = 500  # hard cap on env steps for this rollout

pred_horizon = 16
action_dim = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): `action_horizon` is not set in this cell. It is 8 from the dataset
# cell, but the horizon-sweep cell rebinds it (its last value is 16). Set it here
# explicitly if a specific horizon is intended for the video.

# ─────────────── Rollout ───────────────
# Inference must not use training-mode behaviour (dropout / batch-norm batch stats).
visual_encoder.eval()
noise_prediction_net.eval()

episodes = []
episode_times = []
imgs = []  # one rendered RGB frame per env step, written to video below

env = PushTImageEnv(render_size_vis=512)
# env.seed(100000 + episode_i)  # seeding disabled for this rollout
obs, info = env.reset()

obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
init_action = obs['agent_pos'].copy()
act_deque = collections.deque([init_action] * obs_horizon, maxlen=obs_horizon)

done = False
rewards = []
step_idx = 0

# The sampling schedule does not change between predictions, so set it once.
noise_scheduler.set_timesteps(num_diffusion_iters)

episode_start_time = time.perf_counter()

try:
    with tqdm(total=max_steps) as pbar:
        while not done:
            # Stack the observation / action history as model input
            images_hist = np.stack([x['image'] for x in obs_deque])
            actions_hist = np.stack(act_deque)
            actions_hist = normalize_data(actions_hist, scale=512)

            images_hist = torch.from_numpy(images_hist).to(device, dtype=torch.float32)
            actions_hist = torch.from_numpy(actions_hist).to(device, dtype=torch.float32)

            with torch.no_grad():
                image_features = visual_encoder(images_hist)
                obs_features = torch.cat([image_features, actions_hist], dim=-1)
                obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)

                # Reverse diffusion from pure noise
                pred_actions = torch.randn((1, pred_horizon, action_dim), device=device)
                for k in noise_scheduler.timesteps:
                    noise_pred = noise_prediction_net(
                        sample=pred_actions,
                        timestep=k,
                        global_cond=obs_cond
                    )
                    pred_actions = noise_scheduler.step(
                        model_output=noise_pred,
                        timestep=k,
                        sample=pred_actions
                    ).prev_sample

                pred_actions = pred_actions.cpu().numpy()[0]
                action_exec = normalize_data(pred_actions, scale=1 / 512)

            # Execute action block
            start = obs_horizon - 1
            end = start + action_horizon
            action_block = action_exec[start:end]

            for act in action_block:
                obs, reward, terminated, truncated, info = env.step(act)
                obs_deque.append(obs)
                act_deque.append(act)

                imgs.append(env.render(mode='rgb_array'))

                rewards.append(reward)
                step_idx += 1
                pbar.update(1)
                pbar.set_postfix(reward=float(reward))

                # Stop when the env ends the episode or the step budget is spent;
                # stepping a finished env could reset `done` to False.
                done = terminated or truncated or step_idx >= max_steps
                if done:
                    break

    episode_duration = time.perf_counter() - episode_start_time
finally:
    env.close()

print(f"Score: {max(rewards)} | Episode Time: {episode_duration:.2f} s")

episodes.append(max(rewards))
episode_times.append(episode_duration)


In [None]:
from pathlib import Path

from IPython.display import Video
from skvideo.io import vwrite

# Write the frames collected by the rollout cell to an MP4 file and show it inline.
video_path = Path("../results/vis_5.mp4")

if not imgs:
    raise ValueError("No frames recorded - run the rollout cell first.")

# The encoder cannot create missing directories, so make sure the target exists.
video_path.parent.mkdir(parents=True, exist_ok=True)

# Stack frames into one (T, H, W, C) array (assumes every frame is (H, W, 3)).
vwrite(str(video_path), np.stack(imgs))

Video(str(video_path), embed=True)
