In [8]:
import os
from pathlib import Path
import shutil
import time
from typing import Dict, List, Union
import skill_generator.models.skill_generator as model_sg
import hulc
import torch

In [9]:
def get_all_checkpoints(experiment_folder: Union[str, Path]) -> List[Path]:
    """Return the ``.ckpt`` files in ``<experiment_folder>/saved_models``, oldest first.

    Checkpoints are ordered by file modification time (``st_mtime``), ascending,
    so the most recently written checkpoint is the last element. Entries that
    are not ``.ckpt`` files are filtered out *before* being stat-ed, so unrelated
    files in the folder are never touched.

    Args:
        experiment_folder: Run directory (``str`` or ``Path``) expected to
            contain a ``saved_models`` sub-directory.

    Returns:
        List of checkpoint paths sorted by ascending mtime. Empty when the run
        directory or its ``saved_models`` folder does not exist, or when the
        folder holds no ``.ckpt`` files.
    """
    # A missing experiment folder implies a missing saved_models folder, so a
    # single is_dir() check covers both cases.
    checkpoint_folder = Path(experiment_folder) / "saved_models"
    if not checkpoint_folder.is_dir():
        return []
    checkpoints = [chk for chk in checkpoint_folder.iterdir() if chk.suffix == ".ckpt"]
    # sorted() is stable, so equal mtimes keep iterdir() order, as before.
    return sorted(checkpoints, key=lambda chk: chk.stat().st_mtime)

In [10]:
def get_last_checkpoint(experiment_folder: Path) -> Union[Path, None]:
    """Return the newest checkpoint of a run, or ``None`` if there is none.

    ``get_all_checkpoints`` orders checkpoints by modification time (oldest
    first), so "newest" here means most recently modified, not most recently
    created.
    """
    ordered = get_all_checkpoints(experiment_folder)
    return ordered[-1] if ordered else None

In [22]:
def _sample(mu: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    eps = torch.randn(*mu.size()).to(mu)
    return mu + scale * eps

In [23]:
def _check_direction(dire):
    p = torch.clone(dire)
    n = torch.clone(dire)
    p[dire < 0.] = 0.
    n[dire > 0.] = 0.
    n = torch.abs(n)
    return p, n

In [24]:
def skill_classifier(actions, scale=(1.2, 1.3, 1.0)):
    """Heuristically label each action sequence with one of eight primitive skills.

    Each sequence is scored on eight candidate skills, and the highest score
    wins (ties resolve to the lowest index).

    Args:
        actions: Tensor indexed as ``actions[b, t, c]``. Channels 0-5 are summed
            over time and channel 6 is used as the gripper signal, so at least 7
            channels are required. NOTE(review): channels 0-2 are presumably
            x/y/z translation and 3-5 rotation; confirm against the action
            decoder. T may be 1.
        scale: ``(translation, rotation, gripper)`` divisors applied to the raw
            scores before comparison. Any indexable with at least 3 entries
            works (lists are still accepted).

    Score for each skill index:
        - 0 ``left``: negative net displacement on channel 0.
        - 1 ``right``: positive net displacement on channel 0.
        - 2 ``forward``: positive net displacement on channel 1.
        - 3 ``backward``: negative net displacement on channel 1.
        - 4 ``up``: positive net displacement on channel 2.
        - 5 ``down``: negative net displacement on channel 2.
        - 6 ``rotation``: mean of ``|net|`` over channels 3-5.
        - 7 ``gripper``: total variation of channel 6 over time.

    Returns:
        Long tensor of shape ``(B,)`` with the winning skill index per sequence.
    """
    # Net displacement per channel over the whole sequence: (B, 6).
    net = torch.sum(actions[:, :, :6], dim=1)
    pos = net[:, :3].clamp(min=0.)      # positive part of each translation axis
    neg = (-net[:, :3]).clamp(min=0.)   # magnitude of the negative part
    translation = torch.stack(
        [neg[:, 0], pos[:, 0], pos[:, 1], neg[:, 1], pos[:, 2], neg[:, 2]], dim=-1
    ) / scale[0]

    rotation = net[:, 3:6].abs().sum(dim=1) / 3 / scale[1]

    # Total variation of the gripper channel. With T == 1 the difference tensor is
    # empty and sums to zeros of shape (B,). The former Python loop instead left a
    # plain float that torch.stack rejected.
    gripper_steps = actions[:, 1:, 6] - actions[:, :-1, 6]
    gripper = gripper_steps.abs().sum(dim=1) / scale[2]

    scores = torch.cat([translation, rotation.unsqueeze(-1), gripper.unsqueeze(-1)], dim=-1)
    return torch.argmax(scores, dim=-1)

In [25]:
# Evaluate how often a latent sampled from each learned skill prior decodes into an
# action sequence that `skill_classifier` assigns to that same skill.
batch = 10000
SEED = 0  # NOTE(review): added for reproducibility; rates will differ slightly from the saved output
skill_len = torch.tensor(5)  # number of decoded steps per sample
# Order matches both the prior slot index and the skill_classifier output index.
SKILL_LABELS = [
    'left translation',
    'right translation',
    'forward translation',
    'backward translation',
    'up translation',
    'down translation',
    'rotation',
    'grasp',
]

torch.manual_seed(SEED)

# load_checkpoint
sg_chk_path = './sg_runs/2022-12-09/21-34-22'
sg_chk_path = Path(hulc.__file__).parent.parent / sg_chk_path
chk = get_last_checkpoint(sg_chk_path)
if chk is None:
    # Fail loudly instead of with an opaque AttributeError on `None.as_posix()`.
    raise FileNotFoundError(f"No .ckpt files found under {sg_chk_path / 'saved_models'}")
skill_generator = model_sg.SkillGenerator.load_from_checkpoint(chk.as_posix())
skill_generator.freeze()
prior_locator = skill_generator.prior_locator.eval()
action_decoder = skill_generator.decoder.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
    priors = prior_locator(repeat=batch)
    # Draw every skill's latent before decoding anything, keeping the same RNG
    # consumption order as sampling all skills up front.
    sampled = [
        _sample(priors['p_mu'][:, idx, :], priors['p_scale'][:, idx, :])
        for idx in range(len(SKILL_LABELS))
    ]
    lengths = skill_len.repeat(batch)
    rates = {}
    for idx, (label, latent) in enumerate(zip(SKILL_LABELS, sampled)):
        actions = action_decoder(latent, lengths)
        rates[label] = torch.sum(skill_classifier(actions) == idx) / batch

for label, rate in rates.items():
    print(f'{label} rate: ', rate)

left translation rate:  tensor(0.7897)
right translation rate:  tensor(0.8021)
forward translation rate:  tensor(0.8540)
backward translation rate:  tensor(0.8981)
up translation rate:  tensor(0.7459)
down translation rate:  tensor(0.7321)
rotation rate:  tensor(0.5967)
grasp rate:  tensor(0.6919)
