# Investigate rollouts and resulting datasets

## Round table

### Visualize rollouts

In [None]:
from pathlib import Path
from src.visualization.render_mp4 import (
    mp4_from_pickle_jupyter,
    unpickle_data,
    pickle_data,
)
from src.common.files import get_raw_paths
import random
from tqdm import tqdm

import numpy as np

# Scratch directory that rendered videos are written into.
base_dir: Path = Path("tmp")

In [None]:
# Successful low-randomness round_table rollouts, visited in random order.
paths = get_raw_paths(
    task="round_table",
    randomness="low",
    environment="sim",
    demo_source="rollout",
    demo_outcome="success",
)
random.shuffle(paths)

# Count plus a small sample of the (shuffled) paths
(len(paths), paths[:3])

In [None]:
# Render just the first path of the current ordering to base_dir/<stem>.mp4.
for path in paths[:1]:
    video_file = base_dir.joinpath(path.stem).with_suffix(".mp4")
    mp4_from_pickle_jupyter(path, filename=video_file, fps=20)

### Look at state-space coverage

### Learn from partial successes as a curriculum

In [None]:
# Failed medium-randomness round_table rollouts, in deterministic (sorted)
# order; they are scanned below for episodes that still earned some reward.
paths = sorted(
    get_raw_paths(
        environment="sim",
        demo_source="rollout",
        demo_outcome="failure",
        task="round_table",
        randomness="med",
    )
)

# Per-timestep keys that get truncated, and per-episode keys copied as-is
data_keys = ["observations", "actions", "rewards"]
meta_keys = ["success", "furniture", "action_type"]

(len(paths), paths[:3])

In [None]:
# Preview where the first failure file would land in the "partial_success"
# tree, and create that directory ahead of time.
first_path = paths[0]
new_path = Path(str(first_path).replace("failure", "partial_success"))
new_path.parent.mkdir(exist_ok=True, parents=True)

In [None]:
# Iterate over the failed rollouts, keep the ones where there is at least one
# positive reward, truncate them shortly after their last reward, and save them
# as "partial_success" demos.
truncation_tail = 64
chunk_size = 114  # number of paths processed per run of this cell
n = 0  # index of the chunk of `paths` to process in this run
n_partial = 0

# With a tail of 0 the action that earned the last reward would be sliced off.
if truncation_tail < 1:
    raise ValueError("truncation_tail must be >= 1 so the rewarded action is kept")


def _partial_success_path(src):
    """Return *src* with its ``failure`` directory swapped for ``partial_success``.

    Only the last path component named exactly ``failure`` is replaced, so the
    word appearing elsewhere in the path is left untouched.

    Raises:
        ValueError: if *src* has no ``failure`` component. Without this check the
            destination would equal *src* and the source pickle would be
            overwritten with truncated data.
    """
    parts = list(Path(src).parts)
    for idx in range(len(parts) - 1, -1, -1):
        if parts[idx] == "failure":
            parts[idx] = "partial_success"
            return Path(*parts)
    raise ValueError(f"No 'failure' directory in {src}; refusing to overwrite it")


def _truncate_after_last_reward(data, tail):
    """Return a truncated "partial_success" copy of *data*, or None if unrewarded.

    The episode is cut ``tail`` steps after the last step with a positive
    reward. ``observations`` keep one more entry than ``actions`` and
    ``rewards``. The keys listed in ``meta_keys`` are copied through unchanged,
    and the copy is tagged as a truncated partial success.
    """
    # Use the same "> 0" criterion for filtering and for locating the last
    # reward, so the two can never disagree.
    rewarded_steps = np.flatnonzero(np.asarray(data["rewards"]) > 0)
    if rewarded_steps.size == 0:
        return None

    end = int(rewarded_steps[-1]) + tail
    truncated = {
        "observations": data["observations"][: end + 1],
        "actions": data["actions"][:end],
        "rewards": data["rewards"][:end],
    }
    truncated.update({k: data[k] for k in meta_keys})
    truncated["success"] = "partial_success"
    truncated["truncated"] = True
    truncated["truncation_tail"] = tail
    return truncated


it = tqdm(
    paths[n * chunk_size : (n + 1) * chunk_size],
    postfix={"partial_success": n_partial},
)
for path in it:
    truncated_data = _truncate_after_last_reward(unpickle_data(path), truncation_tail)
    if truncated_data is None:
        continue

    new_path = _partial_success_path(path)
    new_path.parent.mkdir(parents=True, exist_ok=True)

    pickle_data(truncated_data, new_path)
    n_partial += 1
    it.set_postfix({"partial_success": n_partial})

## Lamp

### Visualize the rollouts

In [None]:
from pathlib import Path
from src.visualization.render_mp4 import (
    mp4_from_pickle_jupyter,
    unpickle_data,
    pickle_data,
)
from src.common.files import get_raw_paths
import random
from tqdm import tqdm

# Scratch directory that rendered videos are written into.
base_dir: Path = Path("tmp")

In [None]:
# Successful low-randomness lamp rollouts in descending path order
# (newest first when file names are timestamps).
paths = sorted(
    get_raw_paths(
        task="lamp",
        randomness="low",
        environment="sim",
        demo_source="rollout",
        demo_outcome="success",
    ),
    reverse=True,
)

(len(paths), paths[:3])

In [None]:
for i, path in enumerate(paths[:10], start=1):
    mp4_from_pickle_jupyter(
        path, filename=(base_dir / path.stem).with_suffix(".mp4"), fps=20
    )

## Round table

In [None]:
from pathlib import Path
from src.visualization.render_mp4 import (
    mp4_from_pickle_jupyter,
    unpickle_data,
    pickle_data,
)
from src.common.files import get_raw_paths
import random
from tqdm import tqdm

# Scratch directory that rendered videos are written into.
base_dir: Path = Path("tmp")

In [None]:
# Successful low-randomness round_table rollouts in ascending path order
# (oldest first when file names are timestamps).
paths = sorted(
    get_raw_paths(
        randomness="low",
        task="round_table",
        demo_outcome="success",
        demo_source="rollout",
        environment="sim",
    )
)

(len(paths), paths[:3])

In [None]:
# Render only the first path of the current ordering to base_dir/<stem>.mp4.
for path in paths[:1]:
    out_file = base_dir.joinpath(path.stem).with_suffix(".mp4")
    mp4_from_pickle_jupyter(path, filename=out_file, fps=20)

### Plot coverage of new trajectories

In [None]:
import zarr
import matplotlib.pyplot as plt
import numpy as np


from src.common.files import get_processed_paths

In [None]:
# Processed (zarr) datasets for the teleop and rollout demos.
# NOTE(review): the unpacking relies on sorting placing the rollout path before
# the teleop path, which holds when the two paths differ only in the
# demo_source directory ("rollout" < "teleop"). Confirm this if the layout changes.
processed_paths = get_processed_paths(
    environment="sim",
    demo_source=["teleop", "rollout"],
    demo_outcome="success",
    task="round_table",
    randomness="low",
)
rollout_path, teleop_path = sorted(processed_paths)

(rollout_path, teleop_path)

In [None]:
# Open both processed datasets read-only and pull the first three robot_state
# columns for every timestep. These are presumably the end-effector position,
# given the x/y/z labels used in the plots below (TODO confirm against the
# zarr writer).
z_rollout = zarr.open(str(rollout_path), mode="r")
z_teleop = zarr.open(str(teleop_path), mode="r")

ends_rollout = z_rollout["episode_ends"][:]
ends_teleop = z_teleop["episode_ends"][:]

# Flat (timesteps, 3) arrays, which is what the plots consume
pos_teleop = z_teleop["robot_state"][:, :3]
pos_rollout = z_rollout["robot_state"][:, :3]

# Per-episode views (np.split does not copy), kept for per-episode analysis.
# Concatenating them back would only rebuild an identical copy of pos_*.
episodes_teleop = np.split(pos_teleop, ends_teleop[:-1])
episodes_rollout = np.split(pos_rollout, ends_rollout[:-1])

In [None]:
# Episode count and number of position samples for teleop, then for rollouts
(len(ends_teleop), len(pos_teleop), len(ends_rollout), len(pos_rollout))

### Plot the state-space coverage in 3D

In [None]:
# Single 3D scatter of the teleop positions on their own.
fig = plt.figure(figsize=(6, 6))
ax1 = fig.add_subplot(111, projection="3d")  # one panel filling the figure

ax1.scatter(*pos_teleop.T, label=f"Teleop (n={len(ends_teleop)})", s=0.1)
ax1.legend(frameon=False)
ax1.set_title("Teleop data only")
ax1.set(xlabel="x", ylabel="y", zlabel="z")

plt.show()

In [None]:
# Side-by-side 3D scatter plots of teleop (left) and rollout (right) positions,
# saved to figs/teleop_rollout.png.
from pathlib import Path

fig = plt.figure(figsize=(12, 6))

ax1 = fig.add_subplot(121, projection="3d")  # 1x2 grid, left panel
ax1.scatter(*pos_teleop.T, label=f"Teleop (n={len(ends_teleop)})", s=0.1)
ax1.legend(frameon=False)
ax1.set_title("Teleop data")

ax2 = fig.add_subplot(122, projection="3d")  # 1x2 grid, right panel
ax2.scatter(*pos_rollout.T, label=f"Rollout (n={len(ends_rollout)})", s=0.2, alpha=0.5)
ax2.legend(frameon=False)
ax2.set_title("Rollout data")

# Give both panels identical limits so coverage can be compared by eye.
# Independently autoscaled axes stretch each cloud to fill its own panel.
lo = np.minimum(pos_teleop.min(axis=0), pos_rollout.min(axis=0))
hi = np.maximum(pos_teleop.max(axis=0), pos_rollout.max(axis=0))
for ax in (ax1, ax2):
    ax.set_xlim(lo[0], hi[0])
    ax.set_ylim(lo[1], hi[1])
    ax.set_zlim(lo[2], hi[2])
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

# savefig does not create missing directories
Path("figs").mkdir(exist_ok=True)
plt.savefig("figs/teleop_rollout.png")

plt.show()

In [None]:
# Single 3D scatter of the rollout positions on their own. It is saved under
# its own file name: writing to figs/teleop_rollout.png would overwrite the
# side-by-side teleop/rollout comparison figure.
from pathlib import Path

fig = plt.figure(figsize=(6, 6))

ax2 = fig.add_subplot(111, projection="3d")  # one panel filling the figure
ax2.scatter(*pos_rollout.T, label=f"Rollout (n={len(ends_rollout)})", s=0.2, alpha=0.5)
ax2.legend(frameon=False)
ax2.set_title("Rollout data")
ax2.set_xlabel("x")
ax2.set_ylabel("y")
ax2.set_zlabel("z")

# savefig does not create missing directories
Path("figs").mkdir(exist_ok=True)
plt.savefig("figs/rollout.png")

plt.show()

## Look at MLP rollouts

In [None]:
from pathlib import Path
from src.visualization.render_mp4 import (
    mp4_from_pickle_jupyter,
    unpickle_data,
    pickle_data,
)
from src.common.files import get_raw_paths
import random
from tqdm import tqdm

base_dir = Path("tmp")

# Every pickled MLP rollout anywhere under tmp/mlp_rollouts (recursive search)
pickles = [*base_dir.joinpath("mlp_rollouts").rglob("*.pkl")]

pickles

In [None]:
# Render every MLP rollout pickle to base_dir/<stem>.mp4, logging progress.
# NOTE(review): pickles come from a recursive search, so files with the same
# stem in different subdirectories map to the same mp4 name. Confirm that is OK.
for i, path in enumerate(pickles, start=1):
    print(f"Rendering {i}/{len(pickles)}: {path}")
    target = base_dir.joinpath(path.stem).with_suffix(".mp4")
    mp4_from_pickle_jupyter(path, filename=target, fps=20)

## Fix stored state representation in Round Table rollouts

In [3]:
%env DATA_DIR_RAW=/data/scratch-oc40/pulkitag/ankile/furniture-data

env: DATA_DIR_RAW=/data/scratch-oc40/pulkitag/ankile/furniture-data


In [2]:
from pathlib import Path
from src.visualization.render_mp4 import (
    mp4_from_pickle_jupyter,
    unpickle_data,
    pickle_data,
)
from src.common.files import get_raw_paths
import random
from tqdm import tqdm

from furniture_bench.robot.robot_state import ROBOT_STATES, ROBOT_STATE_DIMS

In [3]:
# Ordered state keys and the width of each key
(ROBOT_STATES, ROBOT_STATE_DIMS)

(['ee_pos', 'ee_quat', 'ee_pos_vel', 'ee_ori_vel', 'gripper_width'],
 {'ee_pos': 3,
  'ee_quat': 4,
  'ee_pos_vel': 3,
  'ee_ori_vel': 3,
  'joint_positions': 7,
  'joint_velocities': 7,
  'joint_torques': 7,
  'gripper_width': 1})

In [4]:
base_dir = Path("tmp")

# Successful round_table rollouts (medium randomness), whose robot_state is
# converted further below.
rollout_paths = get_raw_paths(
    task="round_table",
    randomness="med",
    environment="sim",
    demo_source="rollout",
    demo_outcome="success",
)

# Successful round_table teleop demos (low randomness), used below to inspect
# the dict-style robot_state format.
teleop_paths = get_raw_paths(
    task="round_table",
    randomness="low",
    environment="sim",
    demo_source="teleop",
    demo_outcome="success",
)

rollout_paths = sorted(rollout_paths)

(len(rollout_paths), rollout_paths[:3])

Found the following paths:
    /data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim/round_table/rollout/med/success/*.pkl*
Found the following paths:
    /data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim/round_table/teleop/low/success/*.pkl*


(161,
 [PosixPath('/data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim/round_table/rollout/med/success/2024-02-23T02:21:31.pkl.xz'),
  PosixPath('/data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim/round_table/rollout/med/success/2024-02-23T08:02:11.pkl.xz'),
  PosixPath('/data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim/round_table/rollout/med/success/2024-02-23T08:17:46.pkl.xz')])

In [5]:
# Inspect one teleop demo: its top-level keys, then the shape of every entry
# in the robot_state dict of its first observation.
data = unpickle_data(teleop_paths[0])
print(data.keys())

first_robot_state = data["observations"][0]["robot_state"]
[(key, value.shape) for key, value in first_robot_state.items()]

dict_keys(['observations', 'actions', 'rewards', 'skills', 'success', 'furniture', 'error', 'error_description', 'augment_states'])


[('ee_pos', (3,)),
 ('ee_quat', (4,)),
 ('ee_pos_vel', (3,)),
 ('ee_ori_vel', (3,)),
 ('gripper_width', ()),
 ('joint_positions', (7,)),
 ('joint_velocities', (7,)),
 ('joint_torques', (9,))]

In [9]:
# Inspect one rollout from the backup copy. Its robot_state is still a single
# array (shape (14,) in the recorded output) rather than a per-key dict.
backup_path = "/data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim/round_table/rollout/med/success_backup/2024-02-23T02:21:31.pkl.xz"
data = unpickle_data(backup_path)

print(data.keys())
print(data["observations"][0]["robot_state"].shape)

dict_keys(['observations', 'actions', 'rewards', 'success', 'furniture', 'action_type', 'augment_states'])
(14,)


In [22]:
# Convert the flat robot_state arrays stored in the round_table rollouts into
# per-key dicts (keys in ROBOT_STATES order, widths from ROBOT_STATE_DIMS), and
# rewrite each pickle in place. Files that are already converted are skipped,
# so the cell can simply be re-run after an interruption.
rollout_paths = get_raw_paths(
    environment="sim",
    demo_source="rollout",
    demo_outcome="success",
    task="round_table",
    randomness="med",
)

# Slice for each state key, computed once rather than per observation.
state_slices = {}
offset = 0
for state in ROBOT_STATES:
    width = ROBOT_STATE_DIMS[state]
    state_slices[state] = slice(offset, offset + width)
    offset += width
flat_dim = offset

for path in tqdm(rollout_paths):
    data = unpickle_data(path)
    observations = data["observations"]

    # Skip empty episodes and files that were converted on an earlier run
    if len(observations) == 0 or isinstance(observations[0]["robot_state"], dict):
        continue

    for obs in observations:
        flat = obs["robot_state"]
        # A length mismatch means the stored layout is not the one ROBOT_STATES
        # describes. Slicing anyway would silently mislabel state components.
        if flat.shape[-1] != flat_dim:
            raise ValueError(
                f"{path}: robot_state has {flat.shape[-1]} values, expected {flat_dim}"
            )
        # NOTE(review): width-1 entries such as gripper_width come out with
        # shape (1,), while the teleop data stores gripper_width with shape ().
        # Confirm the downstream loader accepts both.
        obs["robot_state"] = {key: flat[sl] for key, sl in state_slices.items()}

    # Write to a temp file next to the original, then swap it in with an atomic
    # rename. An interrupt mid-write (e.g. KeyboardInterrupt) then cannot corrupt
    # the original. The temp name keeps the original suffixes (e.g. ".pkl.xz").
    tmp_path = path.with_name(f".tmp-{path.name}")
    try:
        pickle_data(data, tmp_path)
        tmp_path.replace(path)
    finally:
        # Removes a partial temp file after a failure; no-op after a successful rename
        tmp_path.unlink(missing_ok=True)

 27%|██▋       | 43/158 [1:03:10<2:48:56, 88.14s/it]


KeyboardInterrupt: 