In [87]:
import pickle
from softgym.softgym.utils import visualization
from torch_geometric.nn import fps
import os
import torch
import numpy as np
from pyquaternion import Quaternion

In [4]:
# Root of the pre-generated demonstration data. Set SOFTGYM_MM_DATA_DIR to point
# at a local copy instead of editing this absolute path.
DATA_DIR = os.environ.get(
    "SOFTGYM_MM_DATA_DIR",
    "/data/seita/softgym_mm/data_demo/MMOneSphere_v01_BClone_filtered_ladle_algorithmic_v04_nVars_2000_obs_combo_act_translation_axis_angle",
)

In [8]:
# Load one pickled behavior-cloning episode. pickle.load can execute code, so
# only point this at trusted, locally generated files.
demo_path = os.path.join(DATA_DIR, "BC_0000_100.pkl")
with open(demo_path, 'rb') as demo_file:
    data = pickle.load(demo_file)

In [21]:
# Element 3 of the first observation is the point cloud handed to
# save_pointclouds, which takes a list of clouds.
first_obs = data['obs'][0]
obs_p = [first_obs[3]]

In [22]:
# Give this GIF its own file name: with the default, every call writes
# ./point_cloud_segm.gif and later visualizations would overwrite this one.
visualization.save_pointclouds(obs_p, savedir='.', suffix="obs_segm.gif")

MoviePy - Building file ./point_cloud_segm.gif with imageio.


                                                  

In [45]:
# Work on the single cloud in obs_p; rows with a positive value in column 3
# are kept as tool points.
pcl = obs_p[0]

i_tool = np.flatnonzero(pcl[:, 3] > 0)
tool_pts = pcl[i_tool]

In [46]:
# as_tensor (unlike from_numpy) also accepts an existing tensor, so re-running
# this self-reassigning cell no longer raises; for a numpy input it still
# shares memory with the array.
tool_pts = torch.as_tensor(tool_pts)

In [59]:
# Farthest-point-sample 5% of the tool points using xyz only, starting from a
# fixed (non-random) point so the selection is reproducible.
tool_xyz = tool_pts[:, :3]
sampled_tool_idxs = fps(tool_xyz, ratio=0.05, random_start=False)

In [60]:
# Gather the FPS-selected rows (all columns) into a new tensor.
sampled_tool_pts = tool_pts.index_select(0, sampled_tool_idxs)

In [61]:
# Explicit file name so this GIF is not overwritten by the next default-named
# call (the default always writes ./point_cloud_segm.gif).
visualization.save_pointclouds([sampled_tool_pts], savedir='.', suffix="sampled_tool_pts.gif")

MoviePy - Building file ./point_cloud_segm.gif with imageio.


                                                  

In [66]:
# Element 0 of an observation is what the replay below passes as `info`; its
# first three entries are used as the tool-tip position.
first_info = data['obs'][0][0]
tool_tip_pos = first_info[:3]

In [73]:
# Display the tool-tip position taken from the first observation.
tool_tip_pos

array([-0.00986772,  0.54506862, -0.07499563])

In [70]:
# Sampled tool points (xyz only) expressed as offsets from the tool tip; the
# tensor-minus-ndarray expression yields a torch tensor (see output below).
gt_tool_pts = sampled_tool_pts[:, :3] - tool_tip_pos

In [71]:
# Sanity check: first sampled tool point relative to the tool tip.
gt_tool_pts[0]

tensor([-0.0045, -0.0117, -0.0070], dtype=torch.float64)

In [74]:
# The file name encodes the point count that consumers of this file rely on;
# fail fast instead of writing a misleadingly named file if FPS produced a
# different number of points.
assert len(gt_tool_pts) == 100, f"expected 100 tool points, got {len(gt_tool_pts)}"
with open('100_tool_pts.pkl', 'wb') as f:
    pickle.dump(gt_tool_pts, f)

In [75]:
# Clone first: a plain assignment would alias sampled_tool_pts, so the in-place
# subtraction would also shift it (and shift it again on every re-run of this cell).
all_gt_tool_pts = sampled_tool_pts.clone()
all_gt_tool_pts[:, :3] -= tool_tip_pos

In [76]:
# Explicit file name so the earlier default-named ./point_cloud_segm.gif output
# is not overwritten.
visualization.save_pointclouds([all_gt_tool_pts], savedir='.', suffix="tool_pts_rel_tip.gif")

MoviePy - Building file ./point_cloud_segm.gif with imageio.


                                                  

In [82]:
# NOTE(review): this is the largest coordinate of the *second* point only
# (built-in max over one row). If the tool's overall extent was intended,
# gt_tool_pts.max() or gt_tool_pts.norm(dim=1).max() may be what was meant.
max(gt_tool_pts[1])

tensor(0.1027, dtype=torch.float64)

In [92]:
class MixedMediaToolReducer:
    """Replace the dense tool segment of a point cloud with a small, fixed set
    of tool points posed by a tracked tool orientation.

    Reference tool points (presumably offsets from the tool tip; confirm the
    pickle at ``TOOL_DATA_PATH``) are loaded once and down-sampled with
    farthest point sampling. ``step`` integrates the rotation part of each
    action into ``self.rotation``; ``reduce_tool`` rotates the reduced points by
    that orientation and translates them to the tool tip given in ``info``.
    """
    TOOL_DATA_PATH = "bc/100_tool_pts.pkl"
    # Translation bounds are zero, so clipping discards translation; only the
    # axis-angle rotation components (last three) are used by ``step``.
    ACTION_LOW  = np.array([ 0, 0, 0, -1, -1, -1])
    ACTION_HIGH = np.array([ 0, 0, 0,  1,  1,  1])
    DEG_TO_RAD = np.pi / 180.
    # 10 degrees; each instance divides this by its action_repeat.
    MAX_ROT_AXIS_ANG = (10. * DEG_TO_RAD)

    def __init__(self, args, action_repeat, tool_data_path=None):
        """Load the reference tool points and down-sample them with FPS.

        Args:
            args: object exposing ``reduce_tool_points`` (must be truthy) and
                ``tool_point_num`` (number of tool points to keep).
            action_repeat: number of sub-steps each action is applied for; the
                rotation cap used in ``step`` is divided by this.
            tool_data_path: optional path to the pickled tool points; defaults
                to ``TOOL_DATA_PATH``.

        Raises:
            ValueError: if ``tool_point_num`` is not in ``[1, #stored points]``.
        """
        assert args.reduce_tool_points
        self.tool_point_num = args.tool_point_num
        self.action_repeat = action_repeat

        # Instance attribute shadows the class constant with the per-sub-step cap.
        self.MAX_ROT_AXIS_ANG = type(self).MAX_ROT_AXIS_ANG / action_repeat

        path = self.TOOL_DATA_PATH if tool_data_path is None else tool_data_path
        # pickle.load can execute arbitrary code: only load trusted files.
        with open(path, 'rb') as f:
            self.all_tool_points = pickle.load(f)

        num_available = len(self.all_tool_points)
        if not 0 < self.tool_point_num <= num_available:
            raise ValueError(
                f"tool_point_num must be in [1, {num_available}], got {self.tool_point_num}")

        # Derive the ratio from the stored count instead of assuming 100 points.
        # fps keeps ceil(ratio * N) points computed in floating point, which can
        # overshoot by one (e.g. 7 / 100 * 100 > 7); FPS emits indices in
        # selection order, so truncating yields exactly tool_point_num points.
        ratio = self.tool_point_num / num_available
        sampled_idxs = fps(self.all_tool_points, ratio=ratio, random_start=False)
        sampled_idxs = sampled_idxs[:self.tool_point_num]
        self.tool_points = self.all_tool_points[sampled_idxs].detach().numpy()

        self.rotation = Quaternion()

    def reset(self):
        """Reset the tracked orientation to the identity quaternion."""
        self.rotation = Quaternion()

    def set_axis(self, axis):
        """Set the tracked orientation from a quaternion given as (x, y, z, w)."""
        self.rotation = Quaternion(w=axis[3], x=axis[0], y=axis[1], z=axis[2])

    def step(self, act_raw):
        """Integrate the rotation part of one action into ``self.rotation``.

        Args:
            act_raw: action ``[x, y, z, rx, ry, rz]``; only ``rx, ry, rz``
                (clipped to [-1, 1]) are used.
        """
        act_clip = np.clip(act_raw, a_min=self.ACTION_LOW, a_max=self.ACTION_HIGH)
        axis = act_clip[3:]

        dtheta = np.linalg.norm(act_clip[3:])
        # NOTE(review): this rescale is discontinuous (an angle just above the
        # cap maps to a much smaller one). Kept unchanged on the assumption that
        # it mirrors the simulator's action handling; confirm before editing.
        if dtheta > self.MAX_ROT_AXIS_ANG:
            dtheta = dtheta * self.MAX_ROT_AXIS_ANG / np.sqrt(3)

        # A zero rotation needs some unit axis; this also avoids dividing by 0.
        if dtheta == 0:
            axis = np.array([0., -1., 0.])

        axis = axis / np.linalg.norm(axis)

        for _ in range(self.action_repeat):
            # `axis` is treated as tool-frame: map it to the world frame with the
            # current orientation, then left-multiply the incremental rotation.
            axis_world = self.rotation.rotate(axis)
            qt_rotate = Quaternion(axis=axis_world, angle=dtheta)
            self.rotation = qt_rotate * self.rotation

    def reduce_tool(self, obs, info):
        """Return ``obs`` with its tool rows replaced by the posed reduced points.

        Args:
            obs: 2-D array whose first three columns are xyz and whose remaining
                columns are class flags; rows with column 3 == 1 are tool points.
            info: sequence whose first three entries are the tool-tip position.

        Returns:
            The reduced tool points (rotated by ``self.rotation``, shifted to the
            tool tip, class column 3 set to 1) followed by the non-tool rows of
            ``obs`` in their original order.
        """
        # Boolean mask rather than assuming the tool rows form a prefix of obs.
        obs_notool = obs[obs[:, 3] != 1]

        tool_tip_pos = info[:3]

        # rotation_matrix normalises self.rotation in place and returns the 3x3
        # matrix R; each row p @ R.T equals the quaternion sandwich q p q*.
        rot = self.rotation.rotation_matrix
        rotated_tool_pts = self.tool_points @ rot.T + tool_tip_pos

        num_classes = obs.shape[1] - 3
        tool_onehot = np.zeros((len(rotated_tool_pts), num_classes), dtype=obs.dtype)
        tool_onehot[:, 0] = 1

        tool_reduced = np.concatenate((rotated_tool_pts, tool_onehot), axis=1)
        return np.concatenate((tool_reduced, obs_notool), axis=0)

In [98]:
class Args:
    """Stand-in for the parsed arguments read by MixedMediaToolReducer."""
    tool_point_num = 20        # number of tool points kept after FPS
    reduce_tool_points = True  # the reducer asserts this is truthy

args = Args()

In [99]:
# Replay the demo: for each step keep the raw cloud, build the reduced cloud
# from the current tool orientation, then advance the orientation with that
# step's action. The final observation is not replayed.
num_obs = len(data['obs']) - 1

raw_obs_p, reduced_obs_p = [], []

# action_repeat=8 presumably matches the env setting used for this data; confirm.
tool_reducer = MixedMediaToolReducer(args=args, action_repeat=8)
tool_reducer.reset()

for t in range(num_obs):
    obs = data['obs'][t]
    point_cloud, info = obs[3], obs[0]
    raw_obs_p.append(point_cloud)
    reduced_obs = tool_reducer.reduce_tool(point_cloud, info=info)
    reduced_obs_p.append(reduced_obs)
    tool_reducer.step(data['act_raw'][t])

In [100]:
# Write the raw and tool-reduced replays side by side for visual comparison.
for clouds, suffix in ((raw_obs_p, "raw.gif"), (reduced_obs_p, "reduced.gif")):
    visualization.save_pointclouds(clouds, savedir='.', suffix=suffix)

MoviePy - Building file ./raw.gif with imageio.


                                                            

MoviePy - Building file ./reduced.gif with imageio.


                                                            