In [1]:
from pytorch3d.ops import iterative_closest_point

In [2]:
import os

# Directory holding lh_pial_source.npz / lh_pial_target.npz.
# Set the LH_PIAL_DIR environment variable to point elsewhere instead of
# editing this absolute, machine-specific default.
dir_path = os.environ.get('LH_PIAL_DIR', '/home/anning/Downloads')

In [3]:
import os

In [4]:
# Source surface file: an .npz holding 'Vertices' and 'Faces' arrays (read below).
# Presumably the left-hemisphere ("lh") pial surface, judging by the name.
lh_source = os.path.join(dir_path, 'lh_pial_source.npz')

In [5]:
# Target surface file (same .npz layout); ICP below aligns the source onto it.
lh_target = os.path.join(dir_path, 'lh_pial_target.npz')

In [6]:
import numpy as np

In [7]:
# Read the source mesh arrays; the context manager closes the .npz file afterwards.
with np.load(lh_source) as data:
    s_Vertices, s_Faces = data['Vertices'], data['Faces']

In [8]:
# Read the target mesh arrays; the context manager closes the .npz file afterwards.
with np.load(lh_target) as data:
    t_Vertices, t_Faces = data['Vertices'], data['Faces']

In [9]:
import torch

In [10]:
# Batched (1, N, 3) float32 copy of the source vertices on the GPU for ICP.
# This rebinds the NumPy array, which is why the vertices are reloaded from disk later.
s_Vertices = torch.from_numpy(s_Vertices).float().unsqueeze(0).cuda()

In [11]:
# Batched (1, M, 3) float32 copy of the target vertices on the GPU for ICP.
# This rebinds the NumPy array; the target is reloaded from disk before it is written out.
t_Vertices = torch.from_numpy(t_Vertices).float().unsqueeze(0).cuda()

In [12]:
# Initial rotation for ICP: the identity, batched to shape (1, 3, 3), float32 on the GPU.
R_init = torch.eye(3, dtype=torch.float32).unsqueeze(0).cuda()

In [13]:
# Initial translation for ICP: 0.1 on every axis, shape (1, 3), float32 on the GPU.
# NOTE(review): the reason for a non-zero starting offset is not recorded here.
T_init = torch.full((1, 3), 0.1, dtype=torch.float32).cuda()

In [14]:
# Initial scale for ICP: 1.0, shape (1,), float32 on the GPU.
s_init = torch.ones(1, dtype=torch.float32).cuda()

In [15]:
# (R, T, s) starting transform, passed to ICP as init_transform.
init_trans = R_init, T_init, s_init

In [16]:
# Align the source vertices (X) onto the target vertices (Y), starting from init_trans.
result = iterative_closest_point(
    X=s_Vertices,
    Y=t_Vertices,
    init_transform=init_trans,
    max_iterations=100,
)

In [17]:
# Check that ICP reported convergence before trusting its transform.
result.converged

True

In [18]:
# Estimated rotation, shape (1, 3, 3), and translation, shape (1, 3), mapping source onto target.
result.RTs.R, result.RTs.T

(tensor([[[ 1.0001e+00,  3.6875e-04, -1.2874e-03],
          [-3.8928e-04,  9.9998e-01,  6.0586e-04],
          [ 1.1253e-03, -6.8723e-04,  1.0000e+00]]], device='cuda:0'),
 tensor([[ -0.4752,  17.4667, -22.4932]], device='cuda:0'))

## Apply the ICP transform to the source vertices

In [19]:
# Move the estimated rotation and translation to the CPU, where the vertices are transformed below.
R, T = result.RTs.R.cpu(), result.RTs.T.cpu()

In [20]:
def apply_pcl_transformation(X_t, R, T, s=None):
    """
    Apply a batch of similarity/rigid transformations to a batch of point clouds.

    Points are row vectors right-multiplied by ``R``, matching the convention of
    the ``result.RTs`` transform returned by pytorch3d's ICP:

        X_out = s * (X_t @ R) + T

    Args:
        X_t: point clouds, shape (B, N, d).
        R: rotations, shape (B, d, d), or a single (d, d) matrix shared by the batch.
        T: translations, shape (B, d), or a single (d,) vector shared by the batch.
        s: optional scales: a (B,) tensor, a 0-d tensor, or a Python number.
            ``None`` applies no scaling.

    Returns:
        The transformed point clouds, shape (B, N, d).
    """
    if s is not None:
        # Accept scalars as well as per-cloud scales, and match X_t's dtype and
        # device so the multiply does not fail on a CPU/CUDA or dtype mismatch.
        s = torch.as_tensor(s, dtype=X_t.dtype, device=X_t.device)
        # (B,) -> (B, 1, 1); a scalar becomes (1, 1, 1) and broadcasts over the batch.
        X_t = s.reshape(-1, 1, 1) * X_t

    # For batched (B, d, d) R this is the same product as torch.bmm; matmul also
    # broadcasts an unbatched (d, d) R. T is viewed as (B or 1, 1, d) to add per point.
    X_t = torch.matmul(X_t, R) + T.reshape(-1, 1, T.shape[-1])

    return X_t

In [21]:
# Reload the source vertices from disk, because the earlier `s_Vertices` was replaced
# by a CUDA tensor for ICP. Store them as a batched (1, N, 3) float32 CPU tensor to match R and T.
# Loading and converting in one cell keeps it safe to re-run: calling
# torch.from_numpy again on an already-converted tensor would fail.
with np.load(lh_source) as data:
    s_Vertices = torch.from_numpy(data['Vertices'][np.newaxis, :]).float()

In [24]:
# Apply the full similarity transform found by ICP, including its scale.
# ICP above was run without scale estimation, so s is expected to be 1. Passing it
# anyway keeps the result correct if scale estimation is ever enabled.
s_Vertices_moved = apply_pcl_transformation(s_Vertices, R, T, result.RTs.s.cpu())

In [25]:
# Convert to NumPy for nibabel. np.asarray returns the same float32 array as
# `.numpy()` for a CPU tensor, but it is a no-op on an ndarray, so re-running this
# cell does not fail.
s_Vertices_moved = np.asarray(s_Vertices_moved)
s_Vertices_moved

array([[[ -9.599041 , -89.11702  , -13.905361 ],
        [-10.3878565, -89.212234 , -14.130046 ],
        [-11.276999 , -89.276375 , -14.43433  ],
        ...,
        [-21.52945  ,  59.976704 ,  16.089872 ],
        [-20.858063 ,  61.19194  ,  16.218588 ],
        [-20.876316 ,  61.109924 ,  16.470644 ]]], dtype=float32)

## Save the surfaces in FreeSurfer format

In [26]:
import nibabel as nib

In [27]:
# Expect (1, N, 3): the batch axis is still present and must be removed before writing.
s_Vertices_moved.shape

(1, 152854, 3)

In [28]:
# Drop the batch axis explicitly. `.squeeze()` would also drop any other length-1
# axis. With a batch larger than 1 it would pass 3-D coords and write a wrong
# vertex count, so require exactly one surface here.
assert s_Vertices_moved.shape[0] == 1, 'expected a single batched surface'
nib.freesurfer.write_geometry('lh_moved.pial', s_Vertices_moved[0].astype(float), s_Faces.astype(int))

In [29]:
# Reload the target mesh as NumPy arrays; `t_Vertices` was replaced by a CUDA tensor for ICP.
with np.load(lh_target) as data:
    t_Vertices, t_Faces = data['Vertices'], data['Faces']


In [30]:
# Write the target surface in FreeSurfer geometry format next to lh_moved.pial
# (both paths are relative to the current working directory).
nib.freesurfer.write_geometry(
    'lh_target.pial',
    t_Vertices.astype(float),
    t_Faces.astype(int),
)