In [1]:
from DataLoader import FaceLandmarkDataset
from se3.model import SE3Transformer, SE3UnetV2
import torch
import numpy as np

test_set = FaceLandmarkDataset(
    preprocessing="none",                         #choices=['icp', 'spatial_transformer', 'none']
    break_ds_with="none",        #choices=['rotation', 'translation', 'rotation_translation', 'none']
    split="Test",
    ds_path= "./Facescape",
    category="Neutral",
    references_pointclouds_icp_path="./Preprocessing/reference_pointclouds_for_icp",
    reduce_pointcloud_to=1024
)

# From the original paper (N body experiment)
se3trs = SE3Transformer(
    num_layers=2,
    num_channels=5,
    num_degrees=3,
    div=1,
    n_heads=1,
    si_m='1x1',
    si_e='att',
    x_ij= 'add'
)

unet2 = SE3UnetV2(
    n_layers=1,
    si_m="1x1",
    si_e="att",
    in_features=3,
    hidden_channels=5,  # ref paper 5
    out_features=68,
    pooling_ratio=0.35,
    aggr="sum"
)

  from .autonotebook import tqdm as notebook_tqdm


Loading dataset...
Path file:  ./Facescape/Test/test_neutral.npy
Preprocessing: none
Dataset: Facescape
Category: Neutral


Len dataset:  torch.Size([167, 1024, 3])
Face:  torch.Size([167, 1024, 3])
Landmark:  torch.Size([167, 68, 3])
Heatmaps:  torch.Size([167, 1024, 68])
Scale:  torch.Size([167])


In [2]:
def get_rotation():
    M = np.random.randn(3, 3)
    Q, __ = np.linalg.qr(M)
    return Q

rot = get_rotation()
traslation = torch.rand((1,3), dtype=torch.float32) * 10

In [3]:
x = test_set.faces[0].detach().clone()
#x = x - torch.mean(x, dim=0, keepdim=True)

test_set.faces[1] = (x @ rot).to(torch.float32)
test_set.faces[2] = x + traslation
test_set.faces[3] = (x @ rot).to(torch.float32) + traslation

In [4]:
from Preprocessing.procrustes_icp import visualize_two_pointclouds

visualize_two_pointclouds(test_set.faces[0], test_set.faces[1])
visualize_two_pointclouds(test_set.faces[0], test_set.faces[2])
visualize_two_pointclouds(test_set.faces[0], test_set.faces[3])

In [5]:
G, y = test_set[0]
with torch.no_grad():
    y_hat = se3trs(G)

y_hat = y_hat.view(-1, 3)

G_rotated, y_rotated = test_set[1]
with torch.no_grad():
    y_hat_rotated = se3trs(G_rotated)

y_hat_rotated = y_hat_rotated.view(-1, 3)

G_traslated, y_traslated = test_set[2]
with torch.no_grad():
    y_hat_traslated = se3trs(G_traslated)

y_hat_traslated = y_hat_traslated.view(-1, 3)

G_roto_traslated, y_roto_traslated = test_set[3]
with torch.no_grad():
    y_hat_roto_traslated = se3trs(G_roto_traslated)

y_hat_roto_traslated = y_hat_roto_traslated.view(-1, 3)

print("###### SE3 TRANSFORMER UNIT TEST ######\n")
print("###### ROTATION TEST ######")
print((y_hat @ rot).to(torch.float32))
print(y_hat_rotated)
print("###########################\n")

print("###### TRASLATION TEST ######")
print(y_hat + traslation)
print(y_hat_traslated)
print("#############################\n")

print("###### ROTO-TRASLATION TEST ######")
print((y_hat @ rot).to(torch.float32) + traslation)
print(y_hat_roto_traslated)
print("##################################\n")

###### SE3 TRANSFORMER UNIT TEST ######

###### ROTATION TEST ######
tensor([[-0.2196, -0.4883, -0.3498],
        [-0.3742, -0.5398, -0.0789],
        [ 0.5065,  0.4549, -0.5887],
        ...,
        [-0.2020, -0.5756, -0.3895],
        [-0.3535,  0.0024,  0.0227],
        [-0.1155, -0.1723, -0.2515]])
tensor([[-0.2196, -0.4883, -0.3498],
        [-0.3742, -0.5398, -0.0789],
        [ 0.5065,  0.4549, -0.5887],
        ...,
        [-0.2020, -0.5756, -0.3895],
        [-0.3535,  0.0024,  0.0227],
        [-0.1155, -0.1723, -0.2515]])
###########################

###### TRASLATION TEST ######
tensor([[9.1579, 7.0865, 2.0742],
        [9.4720, 7.0708, 2.1064],
        [8.3376, 6.6750, 2.8690],
        ...,
        [9.1389, 7.1650, 2.0196],
        [9.4041, 6.6403, 2.4453],
        [9.1010, 6.8896, 2.3541]])
tensor([[ 1.1412,  1.0202,  0.4010],
        [-1.5249,  0.6501, -0.1419],
        [-1.4985, -0.1763,  0.7924],
        ...,
        [ 0.5050, -0.5557,  0.8794],
        [ 0.4165,  1.

In [6]:
G, y = test_set[0]
with torch.no_grad():
    y_hat = unet2(G, 'v', 1)

G_rotated, y_rotated = test_set[1]
with torch.no_grad():
    y_hat_rotated = unet2(G_rotated, 'v', 1)

G_traslated, y_traslated = test_set[2]
with torch.no_grad():
    y_hat_traslated = unet2(G_traslated, 'v', 1)

G_roto_traslated, y_roto_traslated = test_set[3]
with torch.no_grad():
    y_hat_roto_traslated = unet2(G_roto_traslated, 'v', 1)
    

print("###### SE3 U-NET UNIT TEST ######\n")
print("###### ROTATION TEST ######")
print((y_hat @ rot).to(torch.float32))
print(y_hat_rotated)
print("###########################\n")

print("###### TRASLATION TEST ######")
print(y_hat + traslation)
print(y_hat_traslated)
print("#############################\n")

print("###### ROTO-TRASLATION TEST ######")
print((y_hat @ rot).to(torch.float32) + traslation)
print(y_hat_roto_traslated)
print("##################################\n")

###### SE3 U-NET UNIT TEST ######

###### ROTATION TEST ######
tensor([[[ 0.2275,  0.6757,  0.8214],
         [ 1.9269,  0.0536,  0.8404],
         [ 1.7290,  0.1810,  1.7571],
         ...,
         [ 1.3131, -0.0295, -0.1818],
         [ 2.0351, -0.6807,  1.3047],
         [ 0.5918,  1.0612,  0.9091]]])
tensor([[[ 0.2275,  0.6757,  0.8214],
         [ 1.9269,  0.0536,  0.8404],
         [ 1.7290,  0.1810,  1.7571],
         ...,
         [ 1.3131, -0.0295, -0.1818],
         [ 2.0351, -0.6807,  1.3047],
         [ 0.5918,  1.0612,  0.9091]]])
###########################

###### TRASLATION TEST ######
tensor([[[9.5679, 6.4730, 3.6175],
         [8.8397, 7.9203, 4.4239],
         [9.6525, 7.7933, 4.8919],
         ...,
         [8.3617, 7.5572, 3.3904],
         [9.3454, 8.6347, 4.4036],
         [9.3464, 6.3592, 4.0940]]])
tensor([[[ 0.8144, -0.3111,  0.3754],
         [ 2.3573,  1.0549, -0.7644],
         [ 0.1172,  2.4850,  0.9675],
         ...,
         [ 1.0059, -0.0228, -0.3944]