In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle
import tarfile
from collections import Counter
from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import zarr
from tqdm import tqdm

from vip import load_vip

In [3]:
# Root of the data tree. Set the FURNITURE_DATA_ROOT environment variable to
# point elsewhere instead of editing this hard-coded absolute path.
root = Path(
    os.environ.get("FURNITURE_DATA_ROOT", "/home/larsankile/furniture-diffusion/data/")
)
randomness = "low"
furniture = "lamp"
extension = ".tar.gz"
filename = furniture + extension

# Input: tarball of pickled raw episodes for this furniture / randomness level.
input_file = root / "raw" / "real" / "image" / randomness / filename
# Output: zarr store that the extracted features are saved to.
output_file = (
    root / "processed" / "real" / "feature" / randomness / furniture / "data.zarr"
)

In [9]:
raw_data = []

# NOTE(review): pickle executes arbitrary code embedded in the data; only run
# this on archives from a trusted source.
with tarfile.open(input_file, "r:gz") as tar:
    # No hard-coded total: the number of tar members need not equal the number
    # of episode pickles, so let tqdm simply count members as it goes.
    for member in tqdm(tar, desc=f"Extracting {furniture}"):
        # Only regular files with a .pkl suffix are episode pickles.
        if not (member.isfile() and member.name.endswith(".pkl")):
            continue
        # extractfile() always returns a file object for regular files.
        with tar.extractfile(member) as f:
            raw_data.append(pickle.load(f))

Extracting lamp:   0%|          | 0/150 [00:00<?, ?it/s]

Extracting lamp: 151it [02:45,  1.10s/it]                         


In [16]:
vip = load_vip(device_id=1).module
# Switch the encoder to inference mode so any normalization / dropout layers
# behave deterministically while extracting features. This is idempotent, so
# it is harmless if load_vip already returned the model in eval mode.
vip.eval()

vip.device

'cuda:1'

In [17]:
def get_features(img, model=None):
    """Embed a single channels-last image with a VIP-style encoder.

    Args:
        img: Array-like image laid out as (height, width, channels); it is
            permuted to channels-first before being fed to the model.
            Presumably a uint8 camera frame — TODO confirm against raw data.
        model: Encoder to apply. Defaults to the notebook-global ``vip``.
            It must expose a ``device`` attribute and accept a
            (1, C, H, W) float32 tensor.

    Returns:
        numpy.ndarray: the model output with the batch dimension removed,
        copied back to host memory.
    """
    if model is None:
        model = vip
    # Create the input on the model's own device instead of a hard-coded
    # "cuda:1", so changing device_id in load_vip cannot break this helper.
    batch = (
        torch.tensor(img, dtype=torch.float32, device=model.device)
        .permute(2, 0, 1)
        .unsqueeze(0)
    )
    with torch.no_grad():
        features = model(batch)
    return features.squeeze(0).cpu().numpy()

In [18]:
observations = []
actions = []
episode_ends = []

for data in tqdm(raw_data):
    # NOTE(review): zip stops at the shorter sequence, so an observation
    # without a matching action (or vice versa) is silently dropped — confirm
    # this is intended for the raw episode format.
    for obs, action in zip(data["observations"], data["actions"]):
        # Each observation is the flattened robot state followed by the VIP
        # features of the two camera images.
        robot_state = np.concatenate(list(obs["robot_state"].values()), -1)
        img1 = get_features(obs["color_image1"])
        img2 = get_features(obs["color_image2"])

        observations.append(np.concatenate((robot_state, img1, img2)))
        actions.append(action)

    # Episode boundaries are exclusive end offsets into the flattened step
    # arrays, i.e. the total number of steps collected so far.
    episode_ends.append(len(observations))

# Convert the lists to numpy arrays.
observations = np.array(observations)
actions = np.array(actions)
episode_ends = np.array(episode_ends)

  0%|          | 0/150 [00:00<?, ?it/s]

  0%|          | 0/150 [00:06<?, ?it/s]


KeyboardInterrupt: 

In [46]:
# Refuse to write an inconsistent dataset (e.g. leftovers from an interrupted
# feature-extraction run): there must be one action per observation, and the
# last episode boundary must equal the total number of steps.
assert len(observations) == len(actions), "observations/actions length mismatch"
assert len(episode_ends) > 0 and episode_ends[-1] == len(observations), (
    "episode_ends does not match the number of steps"
)

output_file.parent.mkdir(parents=True, exist_ok=True)
zarr.save(
    str(output_file),
    observations=observations,
    actions=actions,
    episode_ends=episode_ends,
)