In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [27]:
import pickle
from glob import glob
import matplotlib.pyplot as plt
from collections import Counter
import torch
import tarfile

import numpy as np
import zarr
from tqdm import tqdm

from pathlib import Path
from furniture_bench.robot.robot_state import filter_and_concat_robot_state

from vip import load_vip

In [3]:
# --- Dataset selection -------------------------------------------------------
# Which furniture piece / randomness level to process, and the archive suffix.
furniture = "lamp"
randomness = "low"
extension = ".tar.gz"
filename = f"{furniture}{extension}"

# --- Locations ---------------------------------------------------------------
# Root of the on-disk data tree (absolute, machine-specific path).
root = Path("/home/larsankile/furniture-diffusion/data/")

# Compressed archive that is read by the extraction cell.
input_file = root.joinpath("raw", "real", "image", randomness, filename)

# Zarr store that the final cell writes to.
output_file = root.joinpath(
    "processed", "real", "feature", randomness, furniture, "data.zarr"
)

In [9]:
# Unpickled contents of every ".pkl" member of the archive, in archive order.
raw_data = []

# NOTE(security): unpickling can execute arbitrary code. Only run this cell on
# archives that come from a trusted source.
with tarfile.open(input_file, "r:gz") as tar:
    # No `total=` is passed: the number of members in a streamed archive is not
    # known up front, so tqdm just shows a running count.
    for member in tqdm(tar, desc=f"Extracting {furniture}"):
        # Only regular files whose name contains ".pkl" hold pickled data.
        if not (member.isfile() and ".pkl" in member.name):
            continue

        # `extractfile` may return None; check *before* entering the context
        # manager, since `with None` would raise instead of being skipped.
        member_file = tar.extractfile(member)
        if member_file is None:
            continue

        with member_file:
            content = member_file.read()

        raw_data.append(pickle.loads(content))

Extracting lamp:   0%|          | 0/150 [00:00<?, ?it/s]

Extracting lamp: 151it [02:45,  1.10s/it]                         


In [16]:
# Load the VIP encoder on GPU index 1 and keep the object stored under the
# wrapper's `.module` attribute
# (presumably load_vip returns a DataParallel-style wrapper — TODO confirm).
vip = load_vip(device_id=1).module

# Display the device the model reports, as a sanity check of where it lives.
vip.device

'cuda:1'

In [22]:
def get_features(img_batch, model=None):
    """Encode a batch of images into feature vectors without tracking gradients.

    Args:
        img_batch: 4-D array-like of images. The last axis is moved to
            position 1 before encoding (assumed to be channel-last
            ``(batch, height, width, channels)`` — TODO confirm with the data).
        model: Optional encoder to use. It must be callable on a float32 tensor
            and expose a ``device`` attribute. Defaults to the notebook-level
            ``vip`` model.

    Returns:
        The encoder output moved to the CPU and converted to a NumPy array.
    """
    encoder = vip if model is None else model

    with torch.no_grad():
        # Build the tensor on the encoder's own device rather than a
        # hard-coded one, so input and model can never end up on different GPUs.
        img_tensor = torch.tensor(
            img_batch, dtype=torch.float32, device=encoder.device
        ).permute(0, 3, 1, 2)
        features = encoder(img_tensor).cpu().numpy()
    return features

In [28]:
# Number of frames encoded per forward pass of the image encoder.
batch_size = 256

# Flat, time-ordered buffers across all episodes.
observations = []
actions = []
# Cumulative step count at the end of each episode (exclusive end indices).
episode_ends = []

end_index = 0


def _encode_batch(robot_states, img1_batch, img2_batch):
    """Return one observation vector per step: [robot_state, img1 feats, img2 feats]."""
    img1_features = get_features(np.stack(img1_batch, axis=0))
    img2_features = get_features(np.stack(img2_batch, axis=0))
    return [
        np.concatenate((state, f1, f2))
        for state, f1, f2 in zip(robot_states, img1_features, img2_features)
    ]


for data in tqdm(raw_data):
    # Per-episode buffers; batches never span two episodes.
    robot_state_batch = []
    img1_batch = []
    img2_batch = []

    for obs, action in zip(data["observations"], data["actions"]):
        # Keep the robot state of *every* step next to its images, so each
        # observation is built from its own time step (not from whichever
        # state happened to be current when the batch was flushed).
        robot_state_batch.append(filter_and_concat_robot_state(obs["robot_state"]))
        img1_batch.append(obs["color_image1"])
        img2_batch.append(obs["color_image2"])

        actions.append(action)
        end_index += 1

        # Flush a full batch through the encoder.
        if len(img1_batch) == batch_size:
            observations.extend(
                _encode_batch(robot_state_batch, img1_batch, img2_batch)
            )
            robot_state_batch = []
            img1_batch = []
            img2_batch = []

    # Encode whatever is left over at the end of the episode.
    if img1_batch:
        observations.extend(_encode_batch(robot_state_batch, img1_batch, img2_batch))

    episode_ends.append(end_index)

# Every step must have contributed exactly one observation and one action.
assert len(observations) == len(actions) == end_index

observations = np.array(observations)
actions = np.array(actions)
episode_ends = np.array(episode_ends)

  0%|          | 0/150 [00:00<?, ?it/s]

100%|██████████| 150/150 [04:18<00:00,  1.72s/it]


In [29]:
# Persist the three arrays under their own names in the output Zarr store.
arrays_to_save = {
    "observations": observations,
    "actions": actions,
    "episode_ends": episode_ends,
}
zarr.save(output_file, **arrays_to_save)