In [10]:
%pip install timm einops
%pip install robosuite
%pip install tqdm

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
sys.path.insert(0, "..")

In [3]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.nn import functional as F

from app.vjepa_droid.transforms import make_transforms
from utils.mpc_utils import (
    compute_new_pose,
    poses_to_diff
)

In [4]:
import os
cache_dir = '/workspace/assets/vjepa2/torch_cache'
os.makedirs(cache_dir, exist_ok=True)
torch.hub.set_dir(cache_dir)

In [5]:
def forward_target(c, normalize_reps=True):
    B, C, T, H, W = c.size()
    c = c.permute(0, 2, 1, 3, 4).flatten(0, 1).unsqueeze(2).repeat(1, 1, 2, 1, 1)
    h = encoder(c)
    h = h.view(B, T, -1, h.size(-1)).flatten(1, 2)
    if normalize_reps:
        h = F.layer_norm(h, (h.size(-1),))
    return h


def forward_actions(z, nsamples, grid_size=0.075, normalize_reps=True, action_repeat=1):

    def make_action_grid(grid_size=grid_size):
        action_samples = []
        for da in np.linspace(-grid_size, grid_size, nsamples):
            for db in np.linspace(-grid_size, grid_size, nsamples):
                for dc in np.linspace(-grid_size, grid_size, nsamples):
                    action_samples += [torch.tensor([da, db, dc, 0, 0, 0, 0], device=z.device, dtype=z.dtype)]
        return torch.stack(action_samples, dim=0).unsqueeze(1)

    # Sample grid of actions
    action_samples = make_action_grid()
    print(f"Sampled grid of actions; num actions = {len(action_samples)}")

    def step_predictor(_z, _a, _s):
        _z = predictor(_z, _a, _s)[:, -tokens_per_frame:]
        if normalize_reps:
            _z = F.layer_norm(_z, (_z.size(-1),))
        _s = compute_new_pose(_s[:, -1:], _a[:, -1:])
        return _z, _s

    # Context frame rep and context pose
    z_hat = z[:, :tokens_per_frame].repeat(int(nsamples**3), 1, 1)  # [S, N, D]
    s_hat = states[:, :1].repeat((int(nsamples**3), 1, 1))  # [S, 1, 7]
    a_hat = action_samples  # [S, 1, 7]

    for _ in range(action_repeat):
        _z, _s = step_predictor(z_hat, a_hat, s_hat)
        z_hat = torch.cat([z_hat, _z], dim=1)
        s_hat = torch.cat([s_hat, _s], dim=1)
        a_hat = torch.cat([a_hat, action_samples], dim=1)

    return z_hat, s_hat, a_hat

def loss_fn(z, h):
    z, h = z[:, -tokens_per_frame:], h[:, -tokens_per_frame:]
    loss = torch.abs(z - h)  # [B, N, D]
    loss = torch.mean(loss, dim=[1, 2])
    return loss.tolist()

In [6]:
device = "cuda"

# Initialize VJEPA 2-AC model
encoder, predictor = torch.hub.load("facebookresearch/vjepa2", "vjepa2_ac_vit_giant")
encoder.to(device)
predictor.to(device)

# Initialize transform
crop_size = 256
tokens_per_frame = int((crop_size // encoder.patch_size) ** 2)
transform = make_transforms(
    random_horizontal_flip=False,
    random_resize_aspect_ratio=(1., 1.),
    random_resize_scale=(1., 1.),
    reprob=0.,
    auto_augment=False,
    motion_shift=False,
    crop_size=crop_size,
)

from utils.world_model_wrapper import WorldModel

world_model = WorldModel(
    encoder=encoder,
    predictor=predictor,
    tokens_per_frame=tokens_per_frame,
    transform=transform,
    # Doing very few CEM iterations with very few samples just to run efficiently on CPU...
    # ... increase cem_steps and samples for more accurate optimization of energy landscape
    mpc_args={
        "rollout": 2,
        "samples": 25,
        "topk": 10,
        "cem_steps": 2,
        "momentum_mean": 0.15,
        "momentum_mean_gripper": 0.15,
        "momentum_std": 0.75,
        "momentum_std_gripper": 0.15,
        "maxnorm": 0.075,
        "verbose": True
    },
    normalize_reps=True,
    device=device
)

Using cache found in /workspace/assets/vjepa2/torch_cache/facebookresearch_vjepa2_main


In [None]:
import robosuite as suite
from scipy.spatial.transform import Rotation as R
from IPython.display import HTML
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib
matplotlib.rcParams['animation.embed_limit'] = 100
from tqdm.notebook import tqdm

def get_state(obs):
    '''
    obsからV-JEPA2 ACのstate(7D vector)を計算
    '''
    # 1. 位置 (Cartesian position): 3次元
    eef_pos = obs["robot0_eef_pos"]
    
    # 2. 姿勢 (Orientation) を Euler angles に変換: 3次元
    # robosuiteのクォータニオンは通常 [x, y, z, w]
    quat = obs["robot0_eef_quat"]
    r = R.from_quat(quat)
    # 外的不変(extrinsic)オイラー角 (xyz) を取得
    euler = r.as_euler('xyz', degrees=False)
    
    # 3. グリッパーの状態 (Gripper state): 1次元
    # robot0_gripper_qpos は通常 2次元（左右 of 指）なので平均や片方を使用
    gripper_state = np.mean(obs["robot0_gripper_qpos"], keepdims=True)
    
    # すべてを結合して 7次元ベクトル (sk) を作成
    state = np.concatenate([eef_pos, euler, gripper_state]).astype(np.float32)
    
    return state

# 1. 環境のセットアップ（offscreenレンダリングをTrueにする）
env = suite.make(
    env_name="Lift",
    robots="Panda",
    has_renderer=False,          # ポップアップウィンドウを表示しない
    has_offscreen_renderer=True, # 裏側で描画する
    control_freq=20,
    use_camera_obs=True,         # カメラ画像を取得する
    camera_names="agentview",    # 標準的な視点
    camera_widths=crop_size,
    camera_heights=crop_size,
)

frames = []

# --- 2. ゼロショット用の「目標画像」を取得 (npzからロード) ---
print("npzファイルから目標データをロード中...")

# 1. データの読み込み
goal = np.load("franka_example_goal.npz")
# observationsは (1, H, W, C) なので、最初の1枚 [0] を取り出す
target_img_np_raw = goal["observations"][0]
# statesは (1, 7)
target_state_raw = goal["states"][0]

# 表示用
target_img_display = target_img_np_raw.copy()

# V-JEPA入力用に前処理
target_img_np = target_img_np_raw.copy()
target_img_np = np.expand_dims(target_img_np, axis=0) # [1, H, W, C]
target_img_pt = transform(target_img_np).unsqueeze(0).to(device) # [1, 3, T, H, W]

# T(時間方向)を2にする処理（元のコードに合わせてcat）
target_img_pt = torch.cat((target_img_pt, target_img_pt), dim=2)

with torch.no_grad():
    # V-JEPAに通して特徴量(z_goal)を抽出
    h_goal = encoder(target_img_pt) 
    # 最後のトークン群を抽出
    z_goal = h_goal[:, -tokens_per_frame:]

print(f"ロード完了: 目標画像サイズ {target_img_display.shape}, 目標State {target_state_raw}")

'''
# --- 2. ゼロショット用の「目標画像」を取得 ---
print("目標画像を取得中...")
# 一度理想の場所にブロックを置いて、その画像を z_goal として保存する
# (ここでは今のリセット直後の位置を一旦ゴールとみなす例)
goal_obs = env.reset()
target_img_display = np.flipud(goal_obs["agentview_image"]).copy() # 表示用

# V-JEPA入力用に前処理
target_img_np = np.flipud(goal_obs["agentview_image"]).copy() # 俯瞰画像を取得 np.flipudで反転
target_img_np = np.expand_dims(target_img_np, axis=0)
target_img_pt = transform(target_img_np).unsqueeze(0).to(device) # [1, 3, T, H, W]
target_img_pt = torch.cat((target_img_pt, target_img_pt), dim=2)

with torch.no_grad():
    # V-JEPAに通して特徴量(z_goal)を抽出
    h_goal = encoder(target_img_pt) 
    # 最後のトークン群を抽出
    z_goal = h_goal[:, -tokens_per_frame:]
'''

# --- 3. アームを別の場所からリスタート ---
obs = env.reset() # 別のランダムな位置から開始
initial_img_display = np.flipud(obs["agentview_image"]).copy() # 表示用

frames = []
print("シミュレーション開始...")

# --- 4. 実行ループ (MPC) ---
for step in tqdm(range(200)):
    # frameをappend
    raw_img = obs["agentview_image"]
    frame = np.flipud(raw_img).copy() 
    frames.append(frame) # 混じり物がない純粋な2D画像を保存
    
    # 現在の画像を取得
    current_img_np = np.flipud(obs["agentview_image"]).copy()
    current_img_np = np.expand_dims(current_img_np, axis=0)
    current_img_pt = transform(current_img_np).unsqueeze(0).to(device) # [1, 3, T, H, W]
    current_img_pt = torch.cat((current_img_pt, current_img_pt), dim=2)
    
    with torch.no_grad():
        # 現在の特徴量(z_n)を抽出
        h_n = encoder(current_img_pt)
        z_n = h_n[:, :tokens_per_frame]
        s_n = get_state(obs)
        s_n = torch.from_numpy(s_n).unsqueeze(0).to(device)
        print(s_n.shape)
        
        # V-JEPA 2 AC + MPC (WorldModel) で次の1手を予測
        action = world_model.infer_next_action(z_n, s_n, z_goal)
        
    # アクションを実行
    action_np = action.cpu().numpy()
    obs, reward, done, info = env.step(action_np[0])
    
    if step % 10 == 0:
        print(f"Step {step}: Action {action_np[0, :3]}")

    if done:
        print("目標に到達しました！")
        break

env.close()

# --- 初期状態の画像表示 ---
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(initial_img_display)
plt.title("Initial State Image")
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(target_img_display)
plt.title("Target Goal Image")
plt.axis('off')
plt.show()

# 5. アニメーションの作成
print("Creating animation...")
fig = plt.figure(figsize=(6, 6))
plt.axis('off')
im = plt.imshow(frames[0])

def update(img):
    im.set_data(img)
    return [im]

ani = animation.ArtistAnimation(fig, [[plt.imshow(f, animated=True)] for f in frames], 
                                interval=50, blit=True, repeat_delay=1000)

# 4. HTMLとして表示
plt.close() # 余計な静止画を消す
HTML(ani.to_jshtml())

[1m[32m[robosuite INFO] [0mLoading controller configuration from: /usr/local/lib/python3.11/site-packages/robosuite/controllers/config/robots/default_panda.json (composite_controller_factory.py:121)


[INFO    ][2026-01-09 08:41:32][robosuite_logs      ][load_composite_controller_config] Loading controller configuration from: /usr/local/lib/python3.11/site-packages/robosuite/controllers/config/robots/default_panda.json


  self.gen = func(*args, **kwds)
[1m[32m[robosuite INFO] [0mLoading controller configuration from: /usr/local/lib/python3.11/site-packages/robosuite/controllers/config/robots/default_panda.json (composite_controller_factory.py:121)


npzファイルから目標データをロード中...
ロード完了: 目標画像サイズ (256, 256, 3), 目標State [-2.4413420e-02 -9.9598356e-03  1.0250156e+00 -3.1383517e+00
 -1.4723109e-01  1.0011774e-02 -3.9702363e-04]
[INFO    ][2026-01-09 08:41:32][robosuite_logs      ][load_composite_controller_config] Loading controller configuration from: /usr/local/lib/python3.11/site-packages/robosuite/controllers/config/robots/default_panda.json
シミュレーション開始...


  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([1, 7])
[INFO    ][2026-01-09 08:41:34][utils.mpc_utils     ][cem                      ] new mean: tensor([ 0.0460,  0.0091, -0.0274, -0.3606], device='cuda:0') tensor([0.1328, 0.1342, 0.1402, 1.3161], device='cuda:0')
[INFO    ][2026-01-09 08:41:35][utils.mpc_utils     ][cem                      ] new mean: tensor([ 0.0549,  0.0179, -0.0321, -0.0999], device='cuda:0') tensor([0.1275, 0.1149, 0.1276, 0.9596], device='cuda:0')
Step 0: Action [ 0.01146723 -0.03289493 -0.03577987]
torch.Size([1, 7])
[INFO    ][2026-01-09 08:41:36][utils.mpc_utils     ][cem                      ] new mean: tensor([ 0.0288, -0.0206,  0.0073,  0.3337], device='cuda:0') tensor([0.1393, 0.1407, 0.1414, 1.2534], device='cuda:0')
[INFO    ][2026-01-09 08:41:38][utils.mpc_utils     ][cem                      ] new mean: tensor([ 0.0060, -0.0505, -0.0045,  0.0538], device='cuda:0') tensor([0.1231, 0.1322, 0.1279, 1.0657], device='cuda:0')
torch.Size([1, 7])
[INFO    ][2026-01-09 08:41:39][utils.mpc_util