In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
from glob import glob
from pathlib import Path
import os
from tqdm import tqdm
import pandas as pd
import zarr
import numpy as np

from src.models.vision import get_encoder
from src.data.process_demos import encode_demo
from src.visualization.render_mp4 import create_mp4

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
base_dir = Path(os.environ["FURNITURE_DATA_DIR"])

rollout_dir = base_dir / "raw" / "sim_rollouts"

file_path = rollout_dir / "index.csv"

## Index the raw rollout data

Now done in a standalone script `src.data.index_rollouts`

## Augment an existing Zarr array with new data from the index

In [4]:
base_dir = Path("/data/scratch/ankile/furniture-data/data")

In [5]:
zarr_path = (
    base_dir
    / "processed"
    / "sim"
    / "feature_separate_small"
    / "vip"
    / "one_leg"
    / "data_aug.zarr"
)

store = zarr.open(str(zarr_path), mode="a")

In [6]:
if "rollout_paths" not in store:
    print("Creating rollout_paths dataset")
    store.create_dataset("rollout_paths", shape=(0,), dtype=str)
else:
    print("rollout_paths dataset already exists")

# Remove the skills dataset if it exists
if "skills" in store:
    print("Removing skills dataset")
    del store["skills"]
else:
    print("skills dataset does not exist")

rollout_paths dataset already exists
skills dataset does not exist


In [7]:
# Read in the index file as a dataframe
index = pd.read_csv(file_path)

index = index[index["success"] == True]

# Get the paths to all the successful rollouts
paths = index["path"].values

# Compare with the paths already in the zarr file
zarr_paths = store["rollout_paths"][:]
paths = [p for p in paths if p not in zarr_paths]

len(paths)

152

In [8]:
# Just sanity check the index by loading a rollout
with open(paths[0], "rb") as f:
    rollout = pickle.load(f)

vid1 = [o["color_image1"] for o in rollout["observations"]]
vid2 = [o["color_image2"] for o in rollout["observations"]]
vid = np.concatenate([vid1, vid2], axis=2)

end_idx = np.argmax(rollout["rewards"]) + 1

create_mp4(vid[:end_idx], "test.mp4")

100%|██████████| 455/455 [00:00<00:00, 799.92it/s]

File saved as test.mp4





In [9]:
# Get an encoder
encoder = get_encoder("vip", freeze=True, device="cuda:0")
batch_size = 128

In [10]:
# Iterate over the paths and add them to the zarr file
end_index = store["episode_ends"][-1]

for path in tqdm(paths):
    with open(path, "rb") as f:
        data = pickle.load(f)

    end_idx = np.argmax(data["rewards"]) + 1

    store["action"].append(data["actions"][:end_idx])
    store["rewards"].append(data["rewards"][:end_idx])

    store["episode_ends"].append([end_index := end_index + end_idx])
    store["furniture"].append([data["furniture"]])

    obs = data["observations"][:end_idx]
    demo_robot_states, demo_features1, demo_features2 = encode_demo(
        encoder, batch_size, obs
    )
    store["robot_state"].append(demo_robot_states)
    store["feature1"].append(demo_features1)
    store["feature2"].append(demo_features2)
    store["rollout_paths"].append([path])

 22%|██▏       | 34/152 [03:28<12:55,  6.58s/it]Exception ignored in: <function _xla_gc_callback at 0x7f3e93acad30>
Traceback (most recent call last):
  File "/data/scratch/ankile/miniconda3/envs/rlgpu/lib/python3.8/site-packages/jax/_src/lib/__init__.py", line 103, in _xla_gc_callback
    def _xla_gc_callback(*args):
KeyboardInterrupt: 
 24%|██▎       | 36/152 [03:38<11:25,  5.91s/it]

## Merge together the two index files

In [14]:
old_index = pd.read_csv("/data/pulkitag/data/ankile/furniture-data/data/raw/sim_rollouts/index.csv")
old_index

Unnamed: 0,path,furniture,success
0,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
1,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
2,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
3,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
4,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
...,...,...,...
8809,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
8810,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
8811,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
8812,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False


In [13]:
new_index = pd.read_csv("/data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim_rollouts/index.csv")
new_index

Unnamed: 0,path,furniture,success
0,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
1,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
2,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
3,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
4,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
...,...,...,...
506,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
507,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
508,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
509,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False


In [12]:
old_index.path = old_index.path.str.replace("/data/pulkitag/data/ankile/furniture-data/data", "/data/scratch-oc40/pulkitag/ankile/furniture-data")

old_index.path[0]

'/data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim_rollouts/2024-01-02_18-33-01/rollout_7.pkl'

In [15]:
# Merge the two index files
new_index = pd.concat([old_index, new_index], axis=0, ignore_index=True).reset_index()

new_index

Unnamed: 0,index,path,furniture,success
0,0,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
1,1,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
2,2,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
3,3,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
4,4,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
...,...,...,...,...
9320,9320,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
9321,9321,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
9322,9322,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
9323,9323,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False


In [16]:
# Save the new index back to the original file
new_index.to_csv("/data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim_rollouts/index.csv")

In [17]:
new_index.success.mean()

0.07184986595174263

In [50]:
# Now, after more indexing, see again how it looks
new_index = pd.read_csv("/data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim_rollouts/index.csv")

new_index

Unnamed: 0,path,furniture,success
0,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
1,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
2,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
3,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
4,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
...,...,...,...
12167,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12168,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,True
12169,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12170,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False


In [41]:
bad_start = [i for i, elem in enumerate(new_index["Unnamed: 0"]) if elem.startswith("/data")][0]

new_index.iloc[bad_start:, :]

Unnamed: 0.1,Unnamed: 0,path,furniture,success
12084,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False,
12085,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False,
12086,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False,
12087,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False,
12088,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False,
...,...,...,...,...
12167,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False,
12168,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,True,
12169,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False,
12170,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False,


In [42]:
new_index.iloc[bad_start:, 1:4] = new_index.iloc[bad_start:, 0:3].values

In [43]:
new_index.iloc[bad_start:, :]

Unnamed: 0.1,Unnamed: 0,path,furniture,success
12084,/data/scratch-oc40/pulkitag/ankile/furniture-d...,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12085,/data/scratch-oc40/pulkitag/ankile/furniture-d...,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12086,/data/scratch-oc40/pulkitag/ankile/furniture-d...,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12087,/data/scratch-oc40/pulkitag/ankile/furniture-d...,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12088,/data/scratch-oc40/pulkitag/ankile/furniture-d...,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
...,...,...,...,...
12167,/data/scratch-oc40/pulkitag/ankile/furniture-d...,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12168,/data/scratch-oc40/pulkitag/ankile/furniture-d...,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,True
12169,/data/scratch-oc40/pulkitag/ankile/furniture-d...,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12170,/data/scratch-oc40/pulkitag/ankile/furniture-d...,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False


In [44]:
new_index2 = new_index.drop(columns=["Unnamed: 0"])

In [45]:
new_index2

Unnamed: 0,path,furniture,success
0,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
1,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
2,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
3,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
4,/data/pulkitag/data/ankile/furniture-data/data...,one_leg,False
...,...,...,...
12167,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12168,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,True
12169,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12170,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False


In [51]:
new_index2.path = new_index2.path.str.replace("/data/pulkitag/data/ankile/furniture-data/data", "/data/scratch-oc40/pulkitag/ankile/furniture-data")

new_index2

Unnamed: 0,path,furniture,success
0,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
1,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
2,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
3,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
4,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
...,...,...,...
12167,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12168,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,True
12169,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False
12170,/data/scratch-oc40/pulkitag/ankile/furniture-d...,one_leg,False


In [52]:
new_index2.to_csv("/data/scratch-oc40/pulkitag/ankile/furniture-data/raw/sim_rollouts/index.csv", index=False)

In [58]:
new_index2.success.astype(bool).sum()

984