In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
import pickle
from glob import glob
import matplotlib.pyplot as plt
from collections import Counter
import torch
import tarfile

import numpy as np
import zarr
from tqdm import tqdm

from pathlib import Path
from furniture_bench.robot.robot_state import filter_and_concat_robot_state
from furniture_bench.perception.image_utils import resize_crop

from vip import load_vip

In [13]:
# Base directory that holds both the raw archives and the processed stores.
root = Path("/home/larsankile/furniture-diffusion/data/")

# Dataset selectors; each one becomes a path component below.
env_type = "real"
obs_type = "image"
randomness = "low"
furniture = "one_leg"

# The raw demonstrations for a furniture type live in one compressed archive.
extension = ".tar.gz"
filename = f"{furniture}{extension}"

# raw/<env>/<obs>/<randomness>/<furniture>.tar.gz
input_file = root.joinpath("raw", env_type, obs_type, randomness, filename)

# processed/<env>/<obs>/<randomness>/<furniture>/data.zarr
output_file = root.joinpath(
    "processed", env_type, obs_type, randomness, furniture, "data.zarr"
)

In [5]:
# Upper bound on how many pickled trajectories are read from the archive.
max_samples = 200

# SECURITY NOTE: unpickling executes arbitrary code embedded in the payload.
# Only run this cell on archives that come from a trusted source.
raw_data, n_samples = [], 0
with tarfile.open(input_file, "r:gz") as tar, tqdm(
    total=max_samples, desc=f"Extracting {furniture}"
) as progress:
    for member in tar:
        # Only regular files whose name mentions ".pkl" hold trajectories.
        if not (member.isfile() and ".pkl" in member.name):
            continue

        # extractfile() can return None, so check before using the result as
        # a context manager (entering `with None` would raise).
        f = tar.extractfile(member)
        if f is None:
            continue

        with f:
            data = pickle.loads(f.read())

        raw_data.append(data)
        n_samples += 1
        # Advance the bar per loaded sample so it matches `total`.
        progress.update(1)

        if n_samples >= max_samples:
            break

Extracting one_leg: 100%|██████████| 50/50 [00:41<00:00,  1.19it/s]


In [15]:
# Destination zarr store for the processed simulation data.
# Note: this rebinds `output_file`, replacing any value assigned earlier.
output_file = Path(
    "/home/larsankile/furniture-diffusion/data/processed/sim/image/low/one_leg/data.zarr"
)

# Directory that is searched for the raw simulation trajectory pickles.
input_dir = Path(
    "/home/larsankile/furniture-diffusion/data/raw/sim/full/one_leg/low"
)

In [26]:
# rglob() already searches recursively, so the bare "*.pkl" pattern is enough.
# Sorting fixes the episode order, which otherwise depends on the filesystem.
files = sorted(input_dir.rglob("*.pkl"))

raw_data = []

# SECURITY NOTE: pickle.load runs arbitrary code from the file; only load
# trajectories that come from a trusted source.
for file in tqdm(files):
    with open(file, "rb") as f:
        data = pickle.load(f)
    raw_data.append(data)

100%|██████████| 73/73 [00:02<00:00, 27.64it/s]


## Extract features

In [16]:
# Load the VIP model on GPU index 1 and keep the object exposed through its
# `.module` attribute (presumably unwrapping a parallel wrapper; confirm
# against the vip package).
_vip_loaded = load_vip(device_id=1)
vip = _vip_loaded.module

# Show which device the encoder reports.
vip.device

'cuda:1'

In [22]:
def get_features(img_batch, device=None):
    """Encode a batch of images with the module-level `vip` model.

    Args:
        img_batch: Array-like of images with the channel axis last; it is
            permuted from (N, H, W, C) to (N, C, H, W) before encoding.
        device: Device the input tensor is created on. Defaults to
            `vip.device`, so the input always follows the model instead of
            relying on a hard-coded device string.

    Returns:
        A NumPy array holding the encoder output for each image in the batch.
    """
    if device is None:
        device = vip.device

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        img_tensor = torch.tensor(
            img_batch, dtype=torch.float32, device=device
        ).permute(0, 3, 1, 2)
        features = vip(img_tensor).cpu().numpy()
    return features

In [28]:
# Number of frames sent through the image encoder in one call.
batch_size = 256
observations = []
actions = []
episode_ends = []

# Running count of steps over all episodes; its value after each episode is
# recorded in `episode_ends`.
end_index = 0


def _encode_batch(state_batch, img1_batch, img2_batch):
    """Encode both camera batches and join each frame with its own robot state."""
    img1_features = get_features(np.stack(img1_batch, axis=0))
    img2_features = get_features(np.stack(img2_batch, axis=0))
    return [
        np.concatenate((state, f1, f2))
        for state, f1, f2 in zip(state_batch, img1_features, img2_features)
    ]


for data in tqdm(raw_data):
    # Robot states are batched alongside the images so that every observation
    # is paired with the state from its own timestep. Using one shared
    # `robot_state` at flush time would stamp the last step's state on the
    # whole batch.
    state_batch = []
    img1_batch = []
    img2_batch = []
    for obs, action in zip(data["observations"], data["actions"]):
        state_batch.append(filter_and_concat_robot_state(obs["robot_state"]))
        img1_batch.append(obs["color_image1"])
        img2_batch.append(obs["color_image2"])

        actions.append(action)
        end_index += 1

        if len(img1_batch) == batch_size:
            observations.extend(_encode_batch(state_batch, img1_batch, img2_batch))
            state_batch, img1_batch, img2_batch = [], [], []

    # Handle any remaining frames within each trajectory; batches never span
    # two episodes.
    if img1_batch:
        observations.extend(_encode_batch(state_batch, img1_batch, img2_batch))

    episode_ends.append(end_index)

observations = np.array(observations)
actions = np.array(actions)
episode_ends = np.array(episode_ends)

  0%|          | 0/150 [00:00<?, ?it/s]

100%|██████████| 150/150 [04:18<00:00,  1.72s/it]


In [29]:
# Write the encoded observations, the actions and the episode boundaries to
# the zarr store at `output_file`, one named array each.
feature_arrays = {
    "observations": observations,
    "actions": actions,
    "episode_ends": episode_ends,
}
zarr.save(output_file, **feature_arrays)

## Extract and resize the images

In [27]:
agent_pos = []
image1 = []
image2 = []
actions = []
episode_ends = []

# Running count of steps over all episodes; its value after each episode is
# recorded in `episode_ends`.
end_index = 0

# Shape every stored frame must have; frames with another shape are passed
# through resize_crop.
image_shape = (224, 224, 3)

for data in tqdm(raw_data):
    for obs, action in zip(data["observations"], data["actions"]):
        robot_state = filter_and_concat_robot_state(obs["robot_state"])
        agent_pos.append(robot_state)

        img1 = obs["color_image1"]
        img2 = obs["color_image2"]

        # Check each camera on its own: deciding for both from img1 alone
        # could leave a wrongly sized img2 untouched, or re-process an img2
        # that already has the target shape.
        if img1.shape != image_shape:
            img1 = resize_crop(img1)
        if img2.shape != image_shape:
            img2 = resize_crop(img2)

        image1.append(img1)
        image2.append(img2)

        actions.append(action)

        end_index += 1

    episode_ends.append(end_index)

agent_pos = np.array(agent_pos)
image1 = np.array(image1)
image2 = np.array(image2)
actions = np.array(actions)
episode_ends = np.array(episode_ends)

# Display the array shapes as a quick sanity check.
agent_pos.shape, image1.shape, image2.shape, actions.shape, episode_ends.shape

100%|██████████| 73/73 [00:00<00:00, 530.67it/s]


((35573, 14), (35573, 224, 224, 3), (35573, 224, 224, 3), (35573, 8), (73,))

In [28]:
# Write the robot states, both camera streams, the actions and the episode
# boundaries to the zarr store at `output_file`, one named array each.
image_arrays = {
    "agent_pos": agent_pos,
    "image1": image1,
    "image2": image2,
    "actions": actions,
    "episode_ends": episode_ends,
}
zarr.save(output_file, **image_arrays)