In [1]:
import pytorch3d.transforms as pt
import zarr
import numpy as np
import torch

In [2]:
def numpy_quaternion_to_6d(quaternions_np):
    """Convert real-part-last quaternions to the 6D rotation representation.

    The quaternions are reordered so the real part comes first (the layout
    ``pytorch3d.transforms.quaternion_to_matrix`` expects), converted to
    rotation matrices, and reduced to 6D with
    ``pytorch3d.transforms.matrix_to_rotation_6d``.

    Args:
        quaternions_np: numpy array of shape ``(..., 4)`` whose last axis
            holds one quaternion with the real part in the last position.
            Any number of leading (batch) dimensions is accepted.

    Returns:
        A float32 numpy array of shape ``(..., 6)``.

    Raises:
        ValueError: if the last dimension of the input is not of size 4.
    """
    # Work in float32 regardless of the input dtype.
    quaternions = torch.from_numpy(quaternions_np).float()

    if quaternions.ndim == 0 or quaternions.shape[-1] != 4:
        raise ValueError(
            "Expected quaternions with shape (..., 4), got "
            f"{tuple(quaternions.shape)}"
        )

    # Move the real part from the back to the front. Indexing with `...` and
    # dim=-1 keeps this correct for any number of leading dimensions.
    quaternions = torch.cat(
        [quaternions[..., -1:], quaternions[..., :-1]], dim=-1
    )

    # Convert each quaternion to a 3x3 rotation matrix.
    rot_mats = pt.quaternion_to_matrix(quaternions)

    # matrix_to_rotation_6d drops the last row of each matrix, i.e. it keeps
    # the first two rows flattened into 6 values.
    six_d_representations = pt.matrix_to_rotation_6d(rot_mats)

    # Convert the results back to a numpy array.
    return six_d_representations.numpy()

In [3]:
# Open the combined feature store; a later cell assigns a new array into it,
# so it is opened with mode "a" rather than read-only.
dataset = zarr.open(
    store="/data/scratch/ankile/furniture-data/processed/sim/feature/vip/combined.zarr",
    mode="a",
)

In [4]:
# Columns 3:7 of every action row are taken as the rotation quaternion
# (four values per row, as the displayed shape confirms).
# NOTE(review): assumes the real part is stored last, as
# numpy_quaternion_to_6d expects — confirm against the data producer.
rot_quat = dataset["action"][:, 3:7]

rot_quat.shape

(1128470, 4)

In [5]:
# Convert every quaternion row to its 6-value rotation representation;
# the displayed shape should have the same row count and 6 columns.
rot_6d = numpy_quaternion_to_6d(rot_quat)

rot_6d.shape

(1128470, 6)

In [6]:
# Read the stored action array into memory once and slice it with numpy,
# instead of indexing the zarr array separately for each column range.
action = dataset["action"][:]

# Swap the 4 quaternion columns (3:7) for the 6 rotation-6D columns, keeping
# the first 3 columns and everything from column 7 onward unchanged.
action_6d = np.concatenate([action[:, :3], rot_6d, action[:, 7:]], axis=1)

action_6d.shape

(1128470, 10)

In [7]:
# Write the converted actions into the opened zarr group under "action_6d".
# NOTE(review): presumably this replaces an existing "action_6d" array when
# the cell is re-run — confirm for the installed zarr version.
dataset["action_6d"] = action_6d

In [8]:
# Get the stats and store them
from src.dataset.normalizer import get_data_stats

In [10]:
# Compute statistics over the 10-column action array; the displayed result
# is a dict holding per-column 'min' and 'max' arrays.
stats = get_data_stats(action_6d)

stats

{'min': array([-0.11999829, -0.11999908, -0.08132574, -0.35798502, -0.99999934,
        -0.70630395, -0.99999964, -0.28959465, -0.89785737, -1.        ],
       dtype=float32),
 'max': array([0.11999787, 0.11999953, 0.07999872, 1.        , 0.9999992 ,
        0.9999981 , 0.9999772 , 1.        , 0.49280044, 1.        ],
       dtype=float32)}

In [9]:
# NOTE(review): this repeats the statement of the preceding cell, but it has
# an earlier execution count (9 vs. 10) and its saved output differs, so the
# two outputs were produced from different contents of action_6d. Keep a
# single stats cell and re-run the notebook top to bottom to confirm which
# values are current.
stats = get_data_stats(action_6d)

stats

{'min': array([-0.11999926, -0.1199997 , -0.08132574, -0.99995553, -0.9999999 ,
        -0.9951642 , -0.99999964, -0.99929166, -0.89785737, -1.        ],
       dtype=float32),
 'max': array([0.11999996, 0.11999957, 0.09052899, 1.        , 0.9999998 ,
        0.9999981 , 0.99999875, 1.        , 0.59187496, 1.        ],
       dtype=float32)}