In [1]:
# Reload edited modules automatically before each cell executes, so changes to
# the neural_capability_maps package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Make the repository importable without installing it. The location can be
# overridden with the NCM_REPO environment variable; the default is the original
# author's checkout.
REPO_ROOT = os.environ.get("NCM_REPO", "/home/wtim/neural-capability-maps/")
# Guard against appending the same entry again every time this cell is re-run.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [3]:
import torch
from tabulate import tabulate
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation

import neural_capability_maps.dataset.se3 as se3
import neural_capability_maps.dataset.so3 as so3
from neural_capability_maps.dataset.morphology import sample_morph
from neural_capability_maps.dataset.capability_map import sample_capability_map
from neural_capability_maps.dataset.kinematics import analytical_inverse_kinematics

from neural_capability_maps.logger import binary_confusion_matrix
from neural_capability_maps.visualisation import visualise_predictions

In [4]:
# Fix torch's global RNG so the sampling in later cells is reproducible.
# NOTE(review): presumably sample_morph / se3.random draw from torch's RNG — confirm.
SEED = 1
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f8a1686b530>

In [5]:
# Sample the random morphologies; the cells below only analyse the last one
# (morphs[-1]).
# NOTE(review): the positional arguments (10, 6, True) are passed through
# unchanged — check sample_morph's signature for their meaning.
sample_morph_args = (10, 6, True)
morphs = sample_morph(*sample_morph_args)

In [6]:
# Sample a discretised capability map for the evaluated morphology: each entry
# pairs the sampled cell indices with their labels (stored as booleans).
NUM_MAP_SAMPLES = 100_000

discretisation_labels, discretisation_cell_indices = [], []
for morph in morphs[-1:]:
    sampled_cells, sampled_labels = sample_capability_map(morph, NUM_MAP_SAMPLES)
    discretisation_labels.append(sampled_labels.bool())
    discretisation_cell_indices.append(sampled_cells)

[auto-batch] est. bytes/sample: 14149 (13.82 KiB), free: 34.56 GiB, safety: 0.5, batch_size: 1311236


In [7]:
# Compare the capability-map labels against analytical inverse kinematics.
# Every sampled cell is converted to a pose and solved in chunks; a pose counts as
# reachable when the returned manipulability is not -1.
IK_CHUNK_SIZE = 100_000  # poses per analytical_inverse_kinematics call

tp = []
tn = []
fp = []
fn = []
acc = []
ground_truth = []
poses = []
reachable = []
joints = []

# Loop-invariant: convert the evaluated morphology to double once, not per chunk.
eval_morph = morphs[-1]
eval_morph_double = eval_morph.double()

for label, cell_indices in zip(discretisation_labels, discretisation_cell_indices):
    num_cells = label.shape[0]
    ground_truth += [torch.zeros_like(label)]
    poses += [torch.zeros(num_cells, 4, 4)]
    joints += [torch.zeros(num_cells, eval_morph.shape[0], 1)]
    for start in range(0, num_cells, IK_CHUNK_SIZE):
        chunk = slice(start, start + IK_CHUNK_SIZE)
        poses[-1][chunk] = se3.cell(cell_indices[chunk])
        joint, manipulability = analytical_inverse_kinematics(eval_morph_double, poses[-1][chunk].double())
        ground_truth[-1][chunk] = manipulability.cpu() != -1
        joints[-1][chunk] = joint

    (true_positives, false_negatives), (false_positives, true_negatives) = binary_confusion_matrix(label,
                                                                                                   ground_truth[-1])
    # Percentage of cells whose label matches the IK ground truth.
    accuracy = (ground_truth[-1] == label).sum() / num_cells * 100
    tp += [true_positives]
    tn += [true_negatives]
    fp += [false_positives]
    fn += [false_negatives]
    acc += [accuracy]
    # Percentage of sampled cells that IK finds reachable.
    reachable += [ground_truth[-1].sum() / ground_truth[-1].shape[0] * 100]

headers = ["True Positives", "True Negatives", "False Positives", "False Negatives", "Accuracy", "Reachable"]
table = list(zip(tp, tn, fp, fn, acc, reachable))
print(tabulate(table, headers=headers, floatfmt=".2f"))

  True Positives    True Negatives    False Positives    False Negatives    Accuracy    Reachable
----------------  ----------------  -----------------  -----------------  ----------  -----------
           99.98             91.09               8.91               0.02       92.01        10.39


In [10]:
# Hand the evaluated morphology, its cell poses, the capability-map labels and the
# IK ground truth to the visualiser.
shown_morphs = [morphs[-1]]
visualise_predictions(shown_morphs, poses, discretisation_labels, ground_truth)

In [None]:
# Build `num_geodesics` straight lines in flattened (axis-angle, translation)
# coordinates. Each line starts at the identity pose and heads towards a random
# SE(3) pose, sampled at `line_samples` points.
num_geodesics = 3
line_samples = 5000

directions = se3.random(num_geodesics)

batched_origin = torch.eye(4).repeat(num_geodesics, 1, 1)
axis_angle = so3.to_vector(directions[:, :3, :3])
# reshape(-1, 1) lets a per-geodesic distance broadcast over the 6 flat coordinates
# whether se3.distance returns shape (n,) or (n, 1).
geodesic_distance = se3.distance(directions, batched_origin).reshape(-1, 1)
flat_pose = torch.cat([axis_angle, directions[:, :3, 3]], dim=1) / geodesic_distance
# Rescale each direction so its translational part has unit length.
# NOTE(review): for a per-row scalar distance this also cancels the division above.
flat_pose /= flat_pose[:, 3:].norm(dim=1, keepdim=True)

# Exactly `line_samples` evenly spaced steps in [0, 1). A float-step arange may
# produce one extra element through rounding, which would not fit `line_poses`.
t = (torch.arange(line_samples) / line_samples).repeat(num_geodesics, 1)
lines = t.unsqueeze(2) * flat_pose.unsqueeze(1)

line_poses = torch.eye(4).repeat(num_geodesics, line_samples, 1, 1)
# scipy works on NumPy arrays: convert the rotation vectors explicitly, then bring
# the float64 matrices back as a tensor matching line_poses' dtype and device.
rotvecs = lines[:, :, :3].reshape(-1, 3).detach().cpu().numpy()
rotation_matrices = torch.from_numpy(Rotation.from_rotvec(rotvecs).as_matrix()).to(line_poses)
line_poses[:, :, :3, :3] = rotation_matrices.reshape(num_geodesics, line_samples, 3, 3)
line_poses[:, :, :3, 3] = lines[:, :, 3:]

In [None]:
# IK reachability along each line. IK runs in double precision, as in the
# capability-map evaluation cell, and the result is moved to the CPU so that
# matplotlib can plot it.
line_manipulability = analytical_inverse_kinematics(
    morphs[-1].double(), line_poses.reshape(-1, 4, 4).double()
)[1].cpu().reshape(num_geodesics, line_samples)
ground_truth_line = line_manipulability != -1

In [None]:
# Plot IK reachability along each geodesic, overlaid with classifier predictions
# when they are available. The classifier curve is scaled by 0.9 so both curves
# stay visible where they agree.
# FIXME: no cell in this notebook defines `line_pred`. Compute it from a trained
# model (presumably shaped like ground_truth_line) before this cell to get the overlay.
line_pred = globals().get("line_pred")
# squeeze=False keeps `axes` 2-D, so indexing also works when num_geodesics == 1.
fig, axes = plt.subplots(num_geodesics, 1, figsize=(20, 5), squeeze=False)
for i, ax in enumerate(axes[:, 0]):
    ax.plot(ground_truth_line[i], label="Ground-Truth")
    if line_pred is not None:
        ax.plot(line_pred[i] * 0.9, label="Classifier", alpha=0.7)
    ax.set_ylabel("reachable")
    ax.legend(loc="right", bbox_to_anchor=(1.4, 0.5))
axes[-1, 0].set_xlabel("sample index along geodesic")
plt.show()