# Imports

In [1]:
from robotics.gym_pusht.envs.pusht import PushTImageEnv
from robotics.model_src.dataset import PushTDataset
from robotics.model_src.diffusion_model import ConditionalUnet1D
from robotics.model_src.visual_encoder import CLIPVisualEncoder, CNNVisualEncoder

import numpy as np
import torch
import torch.nn as nn
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 2

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define Env

In [38]:
#@markdown ### **Env Demo**
#@markdown Standard Gym Env (0.21.0 API)

# 0. create env object
env = PushTImageEnv()

# 1. seed env for initial state.
# Seed 0-200 are used for the demonstration dataset.
env.seed(1000)

# 2. must reset before use
obs, info = env.reset()

# 3. 2D positional action space [0,512]
action = env.action_space.sample()

# 4. Standard gym step method
obs, reward, terminated, truncated, info = env.step(action)

# prints and explains each dimension of the observation and action vectors
with np.printoptions(precision=4, suppress=True, threshold=5):
    print("obs['image'].shape:", obs['image'].shape, "float32, [0,1]")
    print("obs['agent_pos'].shape:", obs['agent_pos'].shape, "float32, [0,512]")
    print("action.shape: ", action.shape, "float32, [0,512]")

obs['image'].shape: (3, 96, 96) float32, [0,1]
obs['agent_pos'].shape: (2,) float32, [0,512]
action.shape:  (2,) float32, [0,512]


# Load Data

In [3]:
import zarr

dataset_version = 1

ds = zarr.open(f"../data/demonstrations_snapshot_3.zarr", mode="r")

prev_actions      = ds["data"]["action"][:]        # shape (N, 2)
prev_images       = ds["data"]["img"][:]           # shape (N, 96, 96, 3)
prev_episode_ends = ds["meta"]["episode_ends"][:]          # shape (M,)

print("actions:",      prev_actions.shape, prev_actions.dtype)
print("images:",       prev_images.shape,  prev_images.dtype)
print("episode_ends:", prev_episode_ends.shape, prev_episode_ends.dtype)

actions: (25650, 2) float32
images: (25650, 96, 96, 3) float32
episode_ends: (206,) int64


In [4]:
pred_horizon = 16
obs_horizon = 1
action_horizon = 8

dataset_path = "/home/may33/Downloads/pusht/pusht_cchi_v7_replay.zarr"

# create dataset from file
dataset = PushTDataset(
    data_path=dataset_path,
    prediction_horizon=pred_horizon,
    obs_horizon=obs_horizon,
)

# create dataloader
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    shuffle=True,
    # accelerate cpu-gpu transfer
    pin_memory=True,
    # don't kill worker process afte each epoch
    persistent_workers=True
)

# visualize data in batch
batch = next(iter(dataloader))
print("batch['image'].shape:", batch['img_obs'].shape)
print("batch['act_obs'].shape:", batch['act_obs'].shape)
print("batch['act_pred'].shape", batch['act_pred'].shape)

 12%|█▏        | 25/205 [02:42<19:27,  6.49s/it] 


KeyboardInterrupt: 

In [6]:
image = torch.Tensor(dataset[0]["img_obs"][None, :obs_horizon, :, :, :]).to(device)
act_obs = torch.Tensor(dataset[0]["act_obs"][None, :obs_horizon, :]).to(device)

# Define Models

In [7]:
# visual_encoder = CLIPVisualEncoder().to(device)

visual_encoder = CNNVisualEncoder().to(device)

vision_feature_dim = visual_encoder.get_output_shape()

action_observation_dim = 2

obs_dim = vision_feature_dim + action_observation_dim

action_dim = 2

noise_prediction_net = ConditionalUnet1D(
    input_dim=action_dim,
    global_cond_dim=obs_dim * obs_horizon,
).to(device)

number of parameters: 7.257856e+07


In [8]:
with torch.no_grad():
    image_features = visual_encoder.encode(image.flatten(start_dim=0, end_dim=1))

    image_features = image_features.reshape(*image.shape[:2], -1)

    obs = torch.cat([image_features, act_obs], dim=-1)

    noised_action = torch.randn((1, pred_horizon, action_dim)).to(device)

    timestep_tensor = torch.randint(0, 101, (1,), device=device)

    noise = noise_prediction_net(
        sample=noised_action,
        timestep=timestep_tensor,
        global_cond=obs.flatten(start_dim=1)
    )

    denoised_action = noised_action - noise


In [9]:
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

num_diffusion_iters = 100

noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_diffusion_iters,
    beta_schedule='squaredcos_cap_v2',
    clip_sample=True,
    prediction_type='epsilon'
)

# Train

In [15]:
num_epochs = 100

# EMA params
all_params = list(visual_encoder.parameters()) + list(noise_prediction_net.parameters())
ema = EMAModel(parameters=all_params, power=0.75)

# optimizer
optimizer = torch.optim.AdamW(
    params=all_params,
    lr=1e-4,
    weight_decay=1e-6
)

# LR scheduler
lr_scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=len(dataloader) * num_epochs
)

# train loop
for epoch_idx in range(num_epochs):
    epoch_loss_sum = 0.0

    for nbatch in dataloader:
        # prepare data
        nimage = nbatch['img_obs'][:, :obs_horizon].to(device)   # (B, H, 3,96,96)
        action_obs = nbatch['act_obs'][:, :obs_horizon].to(device) # (B, H, 2)
        action_pred = nbatch['act_pred'].to(device) # (B, P, 2)
        B = action_obs.size(0)

        # forward pass
        image_features = visual_encoder.forward(
            nimage.flatten(end_dim=1) # (B*H,3,96,96)
        ).reshape(*nimage.shape[:2], -1) # (B, H, D=512)

        obs_features = torch.cat([image_features, action_obs], dim=-1)
        obs_cond = obs_features.flatten(start_dim=1) # (B, H*obs_dim)

        noise = torch.randn_like(action_pred)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (B,), device=device).long()

        noisy_actions = noise_scheduler.add_noise(action_pred, noise, timesteps)
        noise_pred = noise_prediction_net(noisy_actions, timesteps, global_cond=obs_cond)

        loss = nn.functional.mse_loss(noise_pred, noise)

        # backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        ema.step(all_params)

        epoch_loss_sum += loss.item()

    avg_loss = epoch_loss_sum / len(dataloader)
    print(f"Epoch {epoch_idx + 1:3d}/{num_epochs} ─ average loss: {avg_loss:.6f}")

# copy EMA weights for inference
ema.copy_to(all_params)

Epoch   1/100 ─ average loss: 0.206289
Epoch   2/100 ─ average loss: 0.047087
Epoch   3/100 ─ average loss: 0.038137
Epoch   4/100 ─ average loss: 0.035086
Epoch   5/100 ─ average loss: 0.033083
Epoch   6/100 ─ average loss: 0.031619
Epoch   7/100 ─ average loss: 0.030763
Epoch   8/100 ─ average loss: 0.028147
Epoch   9/100 ─ average loss: 0.027110
Epoch  10/100 ─ average loss: 0.025061
Epoch  11/100 ─ average loss: 0.025310
Epoch  12/100 ─ average loss: 0.024298
Epoch  13/100 ─ average loss: 0.022232
Epoch  14/100 ─ average loss: 0.022429
Epoch  15/100 ─ average loss: 0.021396
Epoch  16/100 ─ average loss: 0.021284
Epoch  17/100 ─ average loss: 0.020716
Epoch  18/100 ─ average loss: 0.019403
Epoch  19/100 ─ average loss: 0.019365
Epoch  20/100 ─ average loss: 0.019337
Epoch  21/100 ─ average loss: 0.018289
Epoch  22/100 ─ average loss: 0.017934
Epoch  23/100 ─ average loss: 0.017481
Epoch  24/100 ─ average loss: 0.017529
Epoch  25/100 ─ average loss: 0.016947
Epoch  26/100 ─ average l

# Save / Load model.021677
Epoch  19/100 ─ average loss: 0.021065
Epoch  20/100 ─ average loss: 0.019704
Epoch  21/100 ─ average loss: 0.020990
Epoch  22/100 ─ average loss: 0.019516
Epoch  23/100 ─ average loss: 0.019944
Epoch  24/100 ─ average loss: 0.019400

In [13]:
from pathlib import Path


def save_final_models(visual_encoder, noise_pred_net, out_dir):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    torch.save(
        {
            "visual_encoder": visual_encoder.state_dict(),
            "noise_pred_net": noise_pred_net.state_dict(),
        },
        out_dir / "model_final.pth",
    )
    print(f"Saved to {out_dir / 'models.pth'}")

In [66]:
save_final_models(visual_encoder, noise_prediction_net,
                  "../models/v6_success_both_cnn_actions_h1")

Saved to ../models/v6_success_both_cnn_actions_h1/models.pth


In [11]:
def load_final_models(visual_encoder, noise_pred_net, ckpt_path, device="cuda"):
    ckpt_path = Path(ckpt_path)
    state = torch.load(ckpt_path, map_location=device)

    visual_encoder.load_state_dict(state["visual_encoder"], strict=True)
    noise_pred_net.load_state_dict(state["noise_pred_net"], strict=True)

    visual_encoder.to(device).eval()
    noise_pred_net.to(device).eval()
    print(f"Loaded weights from {ckpt_path}")


In [15]:
load_final_models(visual_encoder, noise_prediction_net, "../models/v6_success_both_cnn_actions_h1/model_final.pth")

Loaded weights from ../models/v6_success_both_cnn_actions_h1/model_final.pth


# Inference

In [35]:
import collections
from robotics.model_src.dataset import normalize_data

# ────────────────────────────── Inference ────────────────────────────────────
vis_diffusion = True
last_step_wait = 10
interval = 1
interval_c = 0

max_steps = 200
env = PushTImageEnv(render_size_vis=512)

env.seed(100000)                        # unseen initial states

obs, info = env.reset()

# obs queue
obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
init_action = obs['agent_pos'].copy()   # shape (2,)
act_deque  = collections.deque([init_action] * obs_horizon, maxlen=obs_horizon)

imgs = [env.render(mode='rgb_array')]
rewards  = []
done     = False
step_idx = 0

with tqdm(total=max_steps, desc="Eval PushTImageEnv") as pbar:
    while not done:
        B = 1

        # stack last obs_horizon frames
        images_hist = np.stack([x['image'] for x in obs_deque])    # (H,3,96,96)
        actions_hist = np.stack(act_deque)                         # (H,2)

        # normalize low-dim part to [-1,1]
        actions_hist = normalize_data(actions_hist, scale=512)
        images_hist = images_hist  # images already in [0,1]

        # device transfer
        images_hist = torch.from_numpy(images_hist).to(device, dtype=torch.float32)
        actions_hist = torch.from_numpy(actions_hist).to(device, dtype=torch.float32)

        # infer action sequence via diffusion
        with torch.no_grad():
            image_features = visual_encoder.forward(images_hist)             # (H,512)
            obs_features = torch.cat([image_features, actions_hist], dim=-1)

            obs_cond = obs_features.unsqueeze(0).flatten(start_dim=1)        # (B, H*(512+2))

            # start from Gaussian noise
            pred_actions = torch.randn((B, pred_horizon, action_dim), device=device)

            noise_scheduler.set_timesteps(num_diffusion_iters)
            for k in noise_scheduler.timesteps:
                noise_pred = noise_prediction_net(
                    sample=pred_actions, timestep=k, global_cond=obs_cond
                )
                pred_actions = noise_scheduler.step(
                    model_output=noise_pred, timestep=k, sample=pred_actions
                ).prev_sample

                if vis_diffusion and interval_c % interval == 0:
                    img_traj = env.render_actions(pred_actions[0].cpu().numpy())
                    imgs.append(img_traj)

            if vis_diffusion and interval_c % interval == 0:
                for i in range(last_step_wait):
                    imgs.append(img_traj)

            interval_c += 1

        # denormalize back to [0,512] range
        pred_actions = pred_actions.cpu().numpy()[0]                       # (P,2)
        action_exec = normalize_data(pred_actions, scale=1/512)           # → [0,512]

        # take action_horizon actions starting at obs_horizon-1
        start = obs_horizon - 1
        end   = start + action_horizon
        action_block = action_exec[start:end]                         # (H_a,2)

        # execute the planned actions
        for act in action_block:
            obs, reward, done, _, info = env.step(act)

            obs_deque.append(obs)
            act_deque.append(act)

            rewards.append(reward)
            imgs.append(env.render(mode='rgb_array'))

            step_idx += 1
            pbar.update(1)
            pbar.set_postfix(reward=float(reward))

            if step_idx >= max_steps:
                done = True
            if done:
                break

print('Score: ', max(rewards))


Eval PushTImageEnv: 100%|██████████| 200/200 [00:06<00:00, 30.72it/s, reward=0.702] 

Score:  0.9393150665784007





In [34]:
from IPython.display import Video
from skvideo.io import vwrite
vwrite("../results/vis.mp4", imgs)
Video('vis.mp4', embed=True, width=1024, height=1024)
