In [None]:
from DataLoader import FaceLandmarkDataset
from se3.model import SE3Transformer, SE3UnetV2
import torch
import numpy as np

# Seed the NumPy and torch global RNGs before anything random happens, so that
# "Restart Kernel -> Run All" reproduces the same random draws (the rotation
# sampled with np.random and the translation sampled with torch.rand; any
# torch-based weight initialisation of the models is covered as well).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Test split of the face-landmark dataset. Both the preprocessing step and the
# dataset-side perturbation are switched off ("none"): the rigid transforms
# under test are applied by hand to `test_set.faces` later in this notebook.
test_set = FaceLandmarkDataset(
    preprocessing="none",   # choices: 'icp', 'spatial_transformer', 'none'
    break_ds_with="none",   # choices: 'rotation', 'translation', 'rotation_translation', 'none'
    split="Test",
    ds_path="./Facescape",
    category="Neutral",
    references_pointclouds_icp_path="./Preprocessing/reference_pointclouds_for_icp",
    reduce_pointcloud_to=1024,
)

# Constructor arguments of the SE(3)-Transformer under test, collected in one
# mapping so the whole configuration can be read (or tweaked) in a single place.
se3trs_config = {
    "num_layers": 2,
    "num_channels": 5,
    "num_degrees": 3,
    "div": 1,
    "n_heads": 1,
    "si_m": '1x1',
    "si_e": 'att',
    "x_ij": 'add',
}

# Freshly constructed model; no checkpoint is loaded in this notebook.
se3trs = SE3Transformer(**se3trs_config)

# Constructor arguments of the SE(3) U-Net under test, kept together in one
# mapping and unpacked into the constructor below.
unet2_config = {
    "n_layers": 1,
    "si_m": "1x1",
    "si_e": "att",
    "in_features": 3,
    "hidden_channels": 5,  # 5, as in the reference paper
    "out_features": 68,
    "pooling_ratio": 0.35,
    "aggr": "sum",
}

# Freshly constructed model; no checkpoint is loaded in this notebook.
unet2 = SE3UnetV2(**unet2_config)

In [None]:
def get_rotation(rng=None):
    """Draw a random proper 3x3 rotation matrix (orthogonal, determinant +1).

    A 3x3 Gaussian matrix is QR-factorised to obtain an orthogonal ``Q``.
    Taken as-is, that ``Q`` can have determinant -1 (a reflection), which is
    not a rotation and therefore not a valid transform for an SE(3)
    equivariance check. The result is corrected so that it always lies in SO(3).

    Args:
        rng: Optional ``numpy.random.Generator`` used to draw the Gaussian
            matrix. When omitted, the legacy global NumPy RNG
            (``np.random.randn``) is used, exactly as before.

    Returns:
        A ``numpy.ndarray`` of shape (3, 3) with ``Q @ Q.T == I`` and
        ``det(Q) == +1`` up to floating-point error.
    """
    if rng is None:
        M = np.random.randn(3, 3)
    else:
        M = rng.standard_normal((3, 3))

    Q, R = np.linalg.qr(M)

    # QR is only unique up to the signs of R's diagonal; normalise them to be
    # positive so the result does not depend on the sign convention of the
    # underlying factorisation routine.
    r_diag = np.diag(R)
    Q = Q * np.where(r_diag < 0, -1.0, 1.0)

    # An orthogonal matrix has determinant +1 or -1. Flipping one column turns
    # a reflection into a rotation while keeping the matrix orthogonal.
    if np.linalg.det(Q) < 0:
        Q[:, 0] = -Q[:, 0]

    return Q

# One random rigid motion shared by every check in this notebook: an orthogonal
# 3x3 matrix from get_rotation() and a (1, 3) float32 offset whose entries are
# drawn uniformly from [0, 10).
rot = get_rotation()
traslation = 10 * torch.rand(1, 3, dtype=torch.float32)

In [None]:
# Work on a detached copy of the first test face so that slot 0 itself is
# never modified.
x = test_set.faces[0].detach().clone()

# Points are multiplied by `rot` on the right; the product is cast to float32
# explicitly.
x_rotated = (x @ rot).to(torch.float32)

# Overwrite dataset slots 1-3 with the rotated, translated and roto-translated
# versions of that face, so they can be fetched through `test_set[i]` later.
test_set.faces[1] = x_rotated
# NOTE(review): unlike the rotated copies, this one gets no explicit float32
# cast - confirm `test_set.faces[0]` is already float32.
test_set.faces[2] = x + traslation
test_set.faces[3] = x_rotated + traslation

In [None]:
from Preprocessing.procrustes_icp import visualize_two_pointclouds

# Show the face in slot 0 next to each of the faces stored in slots 1, 2 and 3
# (in that order), one visualisation call per pair.
for transformed_idx in (1, 2, 3):
    visualize_two_pointclouds(test_set.faces[0], test_set.faces[transformed_idx])

In [None]:
def _se3trs_predict(sample_idx):
    """Fetch one sample from ``test_set`` and run ``se3trs`` on it without gradients.

    Args:
        sample_idx: Index passed to ``test_set[...]``.

    Returns:
        ``(graph, target, prediction)`` where ``prediction`` is the network
        output reshaped with ``.view(-1, 3)``.
    """
    graph, target = test_set[sample_idx]
    with torch.no_grad():
        prediction = se3trs(graph)
    return graph, target, prediction.view(-1, 3)


# Slot 0 holds the reference face; slots 1-3 hold the rotated, translated and
# roto-translated copies written into `test_set.faces` earlier in the notebook.
G, y, y_hat = _se3trs_predict(0)
G_rotated, y_rotated, y_hat_rotated = _se3trs_predict(1)
G_traslated, y_traslated, y_hat_traslated = _se3trs_predict(2)
G_roto_traslated, y_roto_traslated, y_hat_roto_traslated = _se3trs_predict(3)

# Each entry: (header, prediction on the original face with the transform
# applied afterwards, prediction on the already-transformed face, footer).
# For an equivariant model the two tensors of an entry should agree.
se3trs_checks = (
    (
        "###### ROTATION TEST ######",
        (y_hat @ rot).to(torch.float32),
        y_hat_rotated,
        "###########################\n",
    ),
    (
        "###### TRASLATION TEST ######",
        y_hat + traslation,
        y_hat_traslated,
        "#############################\n",
    ),
    (
        "###### ROTO-TRASLATION TEST ######",
        (y_hat @ rot).to(torch.float32) + traslation,
        y_hat_roto_traslated,
        "##################################\n",
    ),
)

print("###### SE3 TRANSFORMER UNIT TEST ######\n")
for header, expected, predicted, footer in se3trs_checks:
    print(header)
    print(expected)
    print(predicted)
    # A single number is far easier to judge than two printed tensors.
    print("max |expected - predicted|:", (expected - predicted).abs().max().item())
    print(footer)

In [None]:
def _unet2_predict(sample_idx):
    """Fetch one sample from ``test_set`` and run ``unet2`` on it without gradients.

    Args:
        sample_idx: Index passed to ``test_set[...]``.

    Returns:
        ``(graph, target, prediction)`` where ``prediction`` is the value
        returned by ``unet2(graph, 'v', 1)``, left unreshaped.
    """
    graph, target = test_set[sample_idx]
    with torch.no_grad():
        prediction = unet2(graph, 'v', 1)
    return graph, target, prediction


# Slot 0 holds the reference face; slots 1-3 hold the rotated, translated and
# roto-translated copies written into `test_set.faces` earlier in the notebook.
G, y, y_hat = _unet2_predict(0)
G_rotated, y_rotated, y_hat_rotated = _unet2_predict(1)
G_traslated, y_traslated, y_hat_traslated = _unet2_predict(2)
G_roto_traslated, y_roto_traslated, y_hat_roto_traslated = _unet2_predict(3)

# Each entry: (header, prediction on the original face with the transform
# applied afterwards, prediction on the already-transformed face, footer).
# For an equivariant model the two tensors of an entry should agree.
unet2_checks = (
    (
        "###### ROTATION TEST ######",
        (y_hat @ rot).to(torch.float32),
        y_hat_rotated,
        "###########################\n",
    ),
    (
        "###### TRASLATION TEST ######",
        y_hat + traslation,
        y_hat_traslated,
        "#############################\n",
    ),
    (
        "###### ROTO-TRASLATION TEST ######",
        (y_hat @ rot).to(torch.float32) + traslation,
        y_hat_roto_traslated,
        "##################################\n",
    ),
)

print("###### SE3 U-NET UNIT TEST ######\n")
for header, expected, predicted, footer in unet2_checks:
    print(header)
    print(expected)
    print(predicted)
    # A single number is far easier to judge than two printed tensors.
    print("max |expected - predicted|:", (expected - predicted).abs().max().item())
    print(footer)