<a href="https://colab.research.google.com/github/SBprjcts/DiffusionPolicyTraining/blob/main/PushT_TrainingNotebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Cell 1: Install lerobot with PushT simulator support
%pip install 'lerobot[pusht]' -q

# After installation completes, you may need to restart the runtime
print("Installation complete! If you see dependency errors, go to Runtime → Restart runtime")

Note: you may need to restart the kernel to use updated packages.
Installation complete! If you see dependency errors, go to Runtime → Restart runtime



[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip
ERROR: Invalid requirement: "'lerobot[pusht]'": Expected package name at the start of dependency specifier
    'lerobot[pusht]'
    ^


In [None]:
!git clone https://github.com/huggingface/lerobot.git
%cd lerobot


UsageError: Line magic function `%git` not found.


# Step 1: Collect Your Own PushT Data (Run Locally Only)

**This section must be run on your local machine** (not Colab) because it opens a pygame window for mouse teleoperation.

**How it works:**
1. A PushT simulator window will open
2. Click near the circular agent and drag with your mouse to push the T-block into the target
3. Each episode runs for up to 300 steps (30 seconds at 10 fps)
4. Press `Q` or close the window to stop early within an episode
5. After all episodes are recorded, the dataset is saved locally

**After collecting data**, zip the output folder and upload it to Colab for training (Step 2).

In [None]:
"""
DATA COLLECTION — Run LOCALLY only (needs a pygame display window).
Adapted from diffusion_policy demo_pusht.py, saves to LeRobotDataset format.

Controls:
  - Hover mouse near the blue circle to grab the agent
  - Push the T-block into the green target area
  - Space: hold to pause
  - R: retry current episode
  - Q: quit collection
"""

import shutil
from pathlib import Path

import gymnasium as gym
import gym_pusht  # registers PushT-v0 with gymnasium
import numpy as np
import pygame
from PIL import Image

from lerobot.datasets.lerobot_dataset import LeRobotDataset

# ── Configuration ──────────────────────────────────────────────
NUM_EPISODES = 200           # Number of episodes to collect
MAX_STEPS_PER_EPISODE = 300  # 300 steps = 30s at 10 fps
FPS = 10
REPO_ID = "custom_pusht"
DATASET_ROOT = Path("./custom_pusht_data") / REPO_ID

# Clean previous data if it exists
if DATASET_ROOT.exists():
    shutil.rmtree(DATASET_ROOT)

# ── Create LeRobot dataset on disk ────────────────────────────
features = {
    "observation.image": {
        "dtype": "video",
        "shape": (384, 384, 3),
        "names": ["height", "width", "channels"],
    },
    "observation.state": {
        "dtype": "float32",
        "shape": (2,),
        "names": ["x", "y"],
    },
    "action": {
        "dtype": "float32",
        "shape": (2,),
        "names": ["x", "y"],
    },
}

dataset = LeRobotDataset.create(
    repo_id=REPO_ID,
    fps=FPS,
    features=features,
    root=DATASET_ROOT,
    robot_type="pusht_sim",
    use_videos=True,
    image_writer_processes=0,
    image_writer_threads=0,
)

# ── Open PushT environment ────────────────────────────────────
env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    render_mode="human",
    max_episode_steps=MAX_STEPS_PER_EPISODE,
)

clock = pygame.time.Clock()
episode_count = 0

print(f"Collecting up to {NUM_EPISODES} episodes.")
print("Hover near the blue circle to grab. Push T onto target.")
print("Space=pause, R=retry, Q=quit\n")

while episode_count < NUM_EPISODES:
    obs, info = env.reset()
    teleop = env.unwrapped.teleop_agent()

    episode_frames = []  # buffer frames in case of retry
    retry = False
    pause = False
    done = False

    pygame.display.set_caption(f"Episode {episode_count + 1}/{NUM_EPISODES}")

    while not done:
        # ── Handle keyboard events ──
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                env.close()
                dataset.finalize()
                print(f"\nQuit early. Saved {episode_count} episodes to: {DATASET_ROOT}")
                raise SystemExit
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_q:
                    env.close()
                    dataset.finalize()
                    print(f"\nQuit. Saved {episode_count} episodes to: {DATASET_ROOT}")
                    raise SystemExit
                if event.key == pygame.K_r:
                    retry = True
                if event.key == pygame.K_SPACE:
                    pause = True
            if event.type == pygame.KEYUP:
                if event.key == pygame.K_SPACE:
                    pause = False

        if retry:
            break
        if pause:
            clock.tick(FPS)
            continue

        # ── Teleoperate ──
        action = teleop.act(obs)
        if action is None:
            # Mouse not close enough to agent — idle
            clock.tick(FPS)
            continue

        action = np.array(action, dtype=np.float32)
        next_obs, reward, terminated, truncated, info = env.step(action)

        # Buffer the frame
        episode_frames.append({
            "observation.image": Image.fromarray(obs["pixels"]),
            "observation.state": np.array(obs["agent_pos"], dtype=np.float32),
            "action": action,
            "task": "Push the T-block onto the target.",
        })

        obs = next_obs
        done = terminated or truncated
        clock.tick(FPS)

    if retry:
        print(f"  Retrying episode {episode_count + 1}")
        continue

    # ── Save episode ──
    for frame in episode_frames:
        dataset.add_frame(frame)
    dataset.save_episode()
    episode_count += 1
    print(f"Episode {episode_count}/{NUM_EPISODES} saved ({len(episode_frames)} steps)")

env.close()
dataset.finalize()
print(f"\nDone! Dataset saved to: {DATASET_ROOT}")

  from pkg_resources import resource_stream, resource_exists
  from .autonotebook import tqdm as notebook_tqdm


Collecting up to 200 episodes.
Hover near the blue circle to grab. Push T onto target.
Space=pause, R=retry, Q=quit



  logger.warn(
  gym.logger.warn("Casting input x to numpy array.")
  logger.warn(f"{pre} is not within the observation space.")


AttributeError: 'PushTEnv' object has no attribute 'screen'

: 

# Step 2: Upload Your Custom Dataset to Colab

After collecting data locally, **zip** the dataset folder and upload it to Colab:

1. **Locally**, run: `cd custom_pusht_data && zip -r custom_pusht.zip custom_pusht/`
2. Upload `custom_pusht.zip` to Colab (drag into the file browser, or use the cell below)
3. Run the cell below to unzip it

In [None]:
# Upload and unzip the custom dataset (run on Colab)
from google.colab import files
import zipfile, os

uploaded = files.upload()  # select custom_pusht.zip
zip_name = list(uploaded.keys())[0]

with zipfile.ZipFile(zip_name, 'r') as z:
    z.extractall("custom_pusht_data")

print("Extracted to custom_pusht_data/custom_pusht")
print("Contents:", os.listdir("custom_pusht_data/custom_pusht"))

In [None]:
# Train diffusion policy on your custom dataset
# Change DATASET_PATH to "lerobot/pusht" to use the default HuggingFace dataset instead

import time
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors

# ── Point to your custom dataset ──────────────────────────────
DATASET_PATH = "custom_pusht"                     # repo_id used during collection
DATASET_ROOT = Path("./custom_pusht_data")         # local root folder

output_directory = Path("outputs/train/example_pusht_diffusion")
checkpoint_directory = Path("outputs/train/checkpoints")
output_directory.mkdir(parents=True, exist_ok=True)
checkpoint_directory.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda")
training_steps = 100000
log_freq = 250
checkpoint_freq = 5000          # save a checkpoint every N steps
time_limit_seconds = 3 * 3600   # stop after 3 hours

# Load dataset metadata (features + stats)
dataset_metadata = LeRobotDatasetMetadata(DATASET_PATH, root=DATASET_ROOT / DATASET_PATH)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg)
policy.train()
policy.to(device)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)

delta_timestamps = {
    "observation.image": [-0.1, 0.0],
    "observation.state": [-0.1, 0.0],
    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}

dataset = LeRobotDataset(
    DATASET_PATH,
    root=DATASET_ROOT / DATASET_PATH,
    delta_timestamps=delta_timestamps,
    video_backend="pyav",
)

optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=64,
    shuffle=True,
    pin_memory=device.type != "cpu",
    drop_last=True,
)

# Training loop with checkpointing and time limit
step = 0
done = False
start_time = time.time()

while not done:
    for batch in dataloader:
        batch = preprocessor(batch)
        loss, _ = policy.forward(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_freq == 0:
            elapsed = (time.time() - start_time) / 60
            print(f"step: {step} loss: {loss.item():.3f} ({elapsed:.1f} min elapsed)")

        step += 1

        # Save checkpoint periodically
        if step % checkpoint_freq == 0:
            ckpt_path = checkpoint_directory / f"step_{step}"
            policy.save_pretrained(ckpt_path)
            preprocessor.save_pretrained(ckpt_path)
            postprocessor.save_pretrained(ckpt_path)
            print(f"  ✓ Checkpoint saved at step {step} → {ckpt_path}")

        # Stop if we hit the step limit or time limit
        elapsed_seconds = time.time() - start_time
        if step >= training_steps or elapsed_seconds >= time_limit_seconds:
            if elapsed_seconds >= time_limit_seconds:
                print(f"\n Time limit reached ({elapsed_seconds/3600:.1f}h). Stopping at step {step}.")
            done = True
            break

# Final save
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
elapsed_total = (time.time() - start_time) / 60
print(f"\nTraining complete! {step} steps in {elapsed_total:.1f} min.")
print(f"Final model saved to {output_directory}")
print(f"Checkpoints saved in {checkpoint_directory}")

In [6]:
!lerobot-eval \
    --policy.path=outputs/train/example_pusht_diffusion \
    --env.type=pusht \
    --eval.batch_size=10 \
    --eval.n_episodes=10 \
    --policy.device=cuda


2026-01-21 05:26:31.963522: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768973191.994466   14209 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768973192.005011   14209 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768973192.030511   14209 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768973192.030547   14209 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768973192.030554   14209 computation_placer.cc:177] computation placer alr