In [1]:
from scipy.io import loadmat
import torch
import torch.nn.functional as F
import numpy as np
import time

from sdf.robot_sdf import RobotSdfCollisionNet

In [2]:
# Everything runs on the CPU in single precision; tensor_args is later passed to
# load_weights(...) and splatted into .to(...) for the network.
device = torch.device('cpu', 0)
tensor_args = dict(device=device, dtype=torch.float32)

# The sampled dataset is stored under the 'dataset' key of the MATLAB file.
DATASET_PATH = '../data-sampling/datasets/data_mesh.mat'
data = loadmat(DATASET_PATH)['dataset']
print(data.shape)

(4950000, 19)


In [3]:
# Row window of `data` used by the rest of the notebook (x/y are data[L1:L2]).
# The 1.0 factor is the fraction of the dataset to keep.
L1 = 0
L2 = int(1.0 * data.shape[0])
# Number of rows actually present in x / y; must be L2 - L1 (not L2) so the
# split indices stay in bounds when L1 > 0.
n_size = L2 - L1

# Contiguous (unshuffled) split by fraction of rows: train, then validation,
# then test takes whatever remains.
train_ratio = 0.98
test_ratio = 0.01
val_ratio = 1 - train_ratio - test_ratio

# Boundaries are computed directly (rather than via idx[-1] + 1) so an empty
# preceding split cannot raise IndexError. Validation is sized by val_ratio;
# the previous code sized it with test_ratio by mistake.
train_end = int(n_size * train_ratio)
val_end = int(n_size * (train_ratio + val_ratio))
idx_train = np.arange(0, train_end)
idx_val = np.arange(train_end, val_end)
idx_test = np.arange(val_end, n_size)

print("Number of train data      : ", idx_train.shape[0])
print("Number of validation data : ", idx_val.shape[0])
print("Number of test data       : ", idx_test.shape[0])

Number of train data      :  4851000
Number of validation data :  49500
Number of test data       :  49500


In [4]:
# Inputs are the first 10 columns of the window, targets the remaining columns
# scaled by 100. Both are stored as float16 (half the memory of float32 for the
# ~5M rows); single samples are converted to float32 before inference.
x = torch.Tensor(data[L1:L2, 0:10]).to(device, dtype=torch.float16)
# Scale in float32 *before* the half-precision cast so each value is rounded to
# float16 once; `100 * <half tensor>` rounded, multiplied in half, and rounded again.
y = (100 * torch.Tensor(data[L1:L2, 10:])).to(device, dtype=torch.float16)
dof = x.shape[1]

# Network hyper-parameters; they must agree with the checkpoint loaded below.
s = 256             # width of each entry in `layers`
n_layers = 5
batch_size = 50000  # only used here to build the checkpoint file name
skips = []
# NOTE(review): the file name uses n_layers *before* the decrement below,
# presumably matching the naming used when the checkpoint was saved — do not
# reorder without checking the saved file names.
fname = 'model/sdf_%dx%d_mesh_%d.pt' % (s, n_layers, batch_size)
if not skips:
    n_layers -= 1
nn_model = RobotSdfCollisionNet(in_channels=dof, out_channels=y.shape[1], layers=[s] * n_layers, skips=skips, dropout_ratio=0)
nn_model.load_weights(fname, tensor_args)
nn_model.model.to(**tensor_args)
model = nn_model.model

Weights loaded!


In [5]:
# Held-out test rows; indexing with an index array copies them out of x / y.
x_test, y_test = x[idx_test], y[idx_test]

In [6]:
# First test sample as a (1, dof) float32 batch: a fresh, detached CPU copy
# (slicing 0:1 keeps the leading batch dimension).
x_tmp = x_test[0:1].detach().to('cpu', torch.float32, copy=True)
print(x_tmp)

tensor([[-2.3340,  1.2070, -2.9668, -2.1777,  1.6318,  3.2090,  0.2642,  0.0271,
         -0.0393,  0.1138]])


In [7]:
# Time a single call of the PyTorch model on the one-sample batch. The call
# returns the distances and what appears to be their input Jacobian.
# perf_counter is monotonic and high-resolution, unlike time.time(), which can
# jump and is coarse on some platforms. A single call also includes any
# first-call overhead.
tic = time.perf_counter()
y_pred, j_pred, _ = nn_model.compute_signed_distance_wgrad(x_tmp)
toc = time.perf_counter()
print(y_pred)
print(j_pred)
print(toc - tic)

tensor([[-0.6323,  3.0472, 16.8176, 30.4063, 42.7080, 53.2352, 62.3256, 67.8052,
         69.6175]])
tensor([[[ 6.0892e-01,  1.1861e+00,  1.6857e+00, -6.7983e-01,  5.6590e-01,
           1.4291e+00,  3.1498e+00,  3.6062e+00,  4.3735e+00],
         [ 3.4102e-01, -6.3736e-01, -6.3747e-01, -1.4858e+01, -1.3252e+01,
          -1.0406e+01, -4.5769e+00, -1.5499e+00,  1.5037e-01],
         [-6.7071e-01, -4.7712e-01, -5.1804e-01, -1.4064e+00, -7.2312e-02,
           4.4660e-02,  2.2492e+00,  3.3477e+00,  4.6831e+00],
         [-7.2216e-02, -6.1670e-01, -2.8209e-01,  1.2861e-01, -3.5333e-01,
           1.0201e+01,  2.1503e+01,  2.7276e+01,  2.8689e+01],
         [ 2.7833e-01,  2.7367e-01, -6.4873e-02, -6.8343e-02,  2.7827e-01,
           5.4586e-01, -3.9120e+00, -6.6616e+00, -7.3683e+00],
         [-6.3341e-02, -2.7395e-01,  7.8017e-02, -2.0242e-01, -2.9540e-01,
          -4.4987e-01, -4.4465e+00, -1.0369e+01, -1.5402e+01],
         [ 2.6977e-02, -3.3894e-02,  2.4099e-02,  1.2483e-02,  2.0451e-

In [8]:
# Evaluate the same test sample with the compiled libNJSDF_FUN module and
# compare against the PyTorch result from the previous cell.
import libNJSDF_FUN as NJSDF_FUN

NJSDF_FUN.setNeuralNetwork()
NJSDF_FUN.setNetworkInput(x_test[0,:].cpu().detach().numpy())
# perf_counter: monotonic, high-resolution clock suited to short intervals.
tic_ = time.perf_counter()
g, g_d = NJSDF_FUN.calculateMlpOutput()
toc_ = time.perf_counter()
print(g)
print(g_d)
print(toc_ - tic_)
# Largest absolute disagreement between the C++ and PyTorch distance outputs.
max_abs_diff = np.max(np.abs(np.asarray(g).ravel() - y_pred.detach().cpu().numpy().ravel()))
print("max |g - y_pred| :", max_abs_diff)

Time:0.00226
[-0.63233996  3.04721912 16.81759546 30.4063028  42.70804989 53.23517778
 62.32563993 67.8052121  69.61747842]
[[ 6.08922382e-01  3.41015624e-01 -6.70708854e-01 -7.22157531e-02
   2.78330287e-01 -6.33406391e-02  2.69774710e-02  5.70202890e+01
  -5.73525671e+01  1.71236630e+01]
 [ 1.18611534e+00 -6.37357253e-01 -4.77119849e-01 -6.16700199e-01
   2.73672989e-01 -2.73947025e-01 -3.38935612e-02  8.19642930e+00
  -7.61820669e+00 -9.84135676e+01]
 [ 1.68570034e+00 -6.37463583e-01 -5.18039113e-01 -2.82088147e-01
  -6.48730874e-02  7.80169796e-02  2.40989082e-02  3.18349964e+00
  -8.38566770e+00 -1.10620462e+02]
 [-6.79832654e-01 -1.48576436e+01 -1.40643261e+00  1.28607936e-01
  -6.83430664e-02 -2.02425094e-01  1.24827414e-02  3.51567358e+01
   4.59118175e+01 -7.88306091e+01]
 [ 5.65901156e-01 -1.32520353e+01 -7.23121356e-02 -3.53331555e-01
   2.78271315e-01 -2.95402162e-01  2.04507789e-01  2.38998294e+01
   3.59312246e+01 -8.79325995e+01]
 [ 1.42907867e+00 -1.04056160e+01  4.4659