In [5]:
import glob
import h5py
import numpy as np 
import os
import pickle
import torch 

from omegaconf import DictConfig, OmegaConf
# from tactile_learning.deployment.load_models import load_model
from tactile_learning.datasets.tactile import TactileImage

In [55]:
# Helper script to load models
import numpy as np
import os
import torch
import torch.utils.data as data 
import torchvision.transforms as T

from collections import OrderedDict
from tqdm import tqdm 
from torch.nn.parallel import DistributedDataParallel as DDP

from holobot.robot.allegro.allegro_kdl import AllegroKDL
from tactile_learning.models.custom import TactileJointLinear, TactileImageEncoder

def load_model(cfg, device, model_path):
    # Initialize the model
    if cfg.agent_type == 'bc':
        model = TactileJointLinear(
            input_dim=cfg.tactile_info_dim + cfg.joint_pos_dim,
            output_dim=cfg.joint_pos_dim,
            hidden_dim=cfg.hidden_dim
        )
    elif cfg.agent_type == 'byol': # load the encoder
        model = TactileImageEncoder(
            in_channels=cfg.encoder.in_channels,
            out_dim=cfg.encoder.out_dim
        )
    # print('model: {}'.format(model))
    state_dict = torch.load(model_path)
    
    # Modify the state dict accordingly - this is needed when multi GPU saving was done
    new_state_dict = modify_multi_gpu_state_dict(state_dict)
    
    if cfg.agent_type == 'byol':
        new_state_dict = modify_byol_state_dict(new_state_dict)

    # Load the new state dict to the model 
    model.load_state_dict(new_state_dict)

    # Turn it into DDP - it was saved that way 
    model = DDP(model.to(device), device_ids=[0])

    return model

def modify_multi_gpu_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]
        new_state_dict[name] = v 
    return new_state_dict

def modify_byol_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if 'encoder.net' in k:
            name = k[12:] # Everything after encoder.net
            new_state_dict[name] = v
    # print(new_state_dict['encoder.net'])
    return new_state_dict


In [72]:
# Script to deploy VINN with a saved encoder

# Get the out dir
out_dir = '/home/irmak/Workspace/tactile-learning/tactile_learning/out/2022.11.30/12-17_byol_bs_2048_joystick'

class DeployVINN:
    def __init__(
        self,
        out_dir,
        sensor_indices = (3,7),
        allegro_finger_indices = (0,1)
    ):
        # os.environ["MASTER_ADDR"] = "localhost"
        # os.environ["MASTER_PORT"] = "29505"

        # torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
        torch.cuda.set_device(0)

        self.sensor_indices = sensor_indices 
        self.allegro_finger_indices = [j for i in allegro_finger_indices for j in range(i*3,(i+1)*3)]

        self.device = torch.device('cuda:0')
        self.cfg = OmegaConf.load(os.path.join(out_dir, '.hydra/config.yaml'))
        self.data_path = self.cfg.data_dir
        model_path = os.path.join(out_dir, 'models/byol_encoder.pt')

        self.encoder = load_model(self.cfg, self.device, model_path)
        self.encoder.eval() 

        self.resize_transform = T.Resize((8, 8))

        tactile_values, allegro_tip_positions = self._load_data()
        self._get_all_representations(tactile_values, allegro_tip_positions)

        self.kdl_solver = AllegroKDL()

    def _load_data(self):
        roots = glob.glob(f'{self.data_path}/demonstration_*')
        roots = sorted(roots)

        self.tactile_indices = [] 
        self.allegro_indices = [] 
        self.allegro_action_indices = [] 
        self.allegro_actions = [] 
        tactile_values = [] 
        allegro_tip_positions = []

        for root in roots:
            # Load the indices
            with open(os.path.join(root, 'tactile_indices.pkl'), 'rb') as f:
                self.tactile_indices += pickle.load(f)
            with open(os.path.join(root, 'allegro_indices.pkl'), 'rb') as f:
                self.allegro_indices += pickle.load(f)
            with open(os.path.join(root, 'allegro_action_indices.pkl'), 'rb') as f:
                self.allegro_action_indices += pickle.load(f)

            # Load the data
            with h5py.File(os.path.join(root, 'allegro_fingertip_states.h5'), 'r') as f:
                # print(f['positions'][()].shape)
                allegro_tip_positions.append(f['positions'][()][:, self.allegro_finger_indices])
            with h5py.File(os.path.join(root, 'allegro_commanded_joint_states.h5'), 'r') as f:
                self.allegro_actions.append(f['positions'][()]) # Positions are to be learned - since this is a position control
            with h5py.File(os.path.join(root, 'touch_sensor_values.h5'), 'r') as f:
                tactile_values.append(f['sensor_values'][()][:,self.sensor_indices,:,:])

        # print(self.allegro_tip_positions[0].shape, self.tactile_values[0].shape)
        return tactile_values, allegro_tip_positions

    def _get_tactile_image(self, tactile_value):
        tactile_image = torch.FloatTensor(tactile_value)
        tactile_image = tactile_image.reshape((
            len(self.sensor_indices),  # Total number of sensors - (2,16,3)
            4, 
            4,
            -1
        ))
        # TODO: This will only work for this grid
        tactile_image = torch.concat((tactile_image[0], tactile_image[1]), dim=1)
        tactile_image = torch.permute(tactile_image, (2,0,1))

        return self.resize_transform(tactile_image)

    # tactile_values: (2,16,3)
    # allegro_tip_positions: (6,) - 3 values for each finger
    def _get_one_representation(self, tactile_values, allegro_tip_positions):
        # For each tactile value get the tactile image
        tactile_image = self._get_tactile_image(tactile_values).unsqueeze(dim=0)
        # print('tactile_image.shape: {}'.format(tactile_image.shape))
        tactile_repr = self.encoder(tactile_image)
        tactile_repr = tactile_repr.detach().cpu().numpy().squeeze() # Remove the axes with dimension 1
        # print('tactile_repr.shape: {}'.format(tactile_repr.shape))
        # It should be (64,)
        return np.concatenate((tactile_repr, allegro_tip_positions), axis=0)


    def _get_all_representations(
        self,
        tactile_values,
        allegro_tip_positions
    ):  
        print('Getting all representations')
        pbar = tqdm(total=len(self.tactile_indices))
        # For each tactile value and allegro tip position 
        # get one representation and add it to all representations
        repr_dim = self.cfg.encoder.out_dim + len(self.cfg.dataset.allegro_finger_indices) * 3 
        # print('repr_dim: {}'.format(repr_dim))
        self.all_representations = np.zeros((
            len(self.tactile_indices), repr_dim
        ))

        for index in range(len(self.tactile_indices)):
            demo_id, tactile_id = self.tactile_indices[index]
            _, allegro_tip_id = self.allegro_action_indices[index]

            tactile_value = tactile_values[demo_id][tactile_id] # This should be (2,16,3)
            # print('tactile_value.shape: {}'.format(tactile_value.shape))
            allegro_tip_position = allegro_tip_positions[demo_id][tactile_id] # This should be (6,)
            # print('allegro_tip_position.shape: {}'.format(allegro_tip_position.shape))
            representation = self._get_one_representation(
                tactile_value, 
                allegro_tip_position
            )
            self.all_representations[index, :] = representation[:]
            pbar.update(1)

        pbar.close()

    # tactile_values.shape: (16,15,3)
    # joint_state.shape: (16)
    def get_action(self, tactile_values, joint_state):
        # Get the allegro tip positions with kdl solver 
        fingertip_positions = self.kdl_solver.get_fingertip_coords(joint_state) # - fingertip position.shape: (12)

        # Get the tactile image from the tactile values
        curr_tactile_values = tactile_values[self.sensor_indices,:,:]
        curr_fingertip_position = fingertip_positions[self.allegro_finger_indices]

        print('curr_tactile_values.shape: {}, curr_fingertip_position.shape: {}'.format(
            curr_tactile_values.shape, curr_fingertip_position.shape
        ))

        assert curr_tactile_values.shape == (2,16,3) and curr_fingertip_position.shape == (6,)

        # Get the representation with the given tactile value
        curr_representation = self._get_one_representation(
            curr_tactile_values, 
            curr_fingertip_position
        )

        nn_id = self._get_knn_idxs(curr_representation, k=0)
        print('nn_id: {}'.format(nn_id))

        # Get the applied action at that id
        demo_id, action_id = self.allegro_action_indices[nn_id[0]]
        nn_action = self.allegro_actions[demo_id][action_id]

        print('nn_action: {}'.format(nn_action))

        return nn_action

    def _get_sorted_idxs(self, representation):
        l1_distances = self.all_representations - representation
        print('l1_distances.shape: {}'.format(l1_distances.shape))
        l2_distances = np.linalg.norm(l1_distances, axis = 1)

        sorted_idxs = np.argsort(l2_distances)
        return sorted_idxs

    def _get_knn_idxs(self, representation, k=0):
        sorted_idxs = self._get_sorted_idxs(representation)
        
        knn_idxs = sorted_idxs[:k+1]
        return knn_idxs


In [73]:
vinn = DeployVINN(out_dir)

Getting all representations


100%|██████████| 52870/52870 [00:27<00:00, 1914.73it/s]


In [74]:
x = np.random.rand(15,16,3)
y = np.random.rand(16)

In [75]:
action = vinn.get_action(x, y)

curr_tactile_values.shape: (2, 16, 3), curr_fingertip_position.shape: (6,)
l1_distances.shape: (52870, 70)
nn_id: [33830]
nn_action: [0.         0.09693163 0.03691709 0.21400803 0.         0.01564443
 0.05228145 0.22613142 0.08778772 0.16179268 0.07308225 0.09116157
 0.51118785 0.26664537 0.7438539  0.65285134]


In [71]:
x = np.random.rand(13)
x[:0]

array([], dtype=float64)