In [3]:
from pathlib import Path
import logging

import torchcpd
import pyvista as pv
import torch
from tqdm import tqdm

project_dir = Path.cwd().parent
vtk_dir = project_dir / "output_ssm_vtk"
aligned_vtk_dir = project_dir / "output_ssm_vtk_aligned"
volume_points_cloud_dir = project_dir / "output_ssm_vtk_volume"

# case name (sub-directory of vtk_dir) -> that case's .vtk files, sorted by filename.
# Downstream code indexes these lists by position as the phase number, so the
# filename sort order defines the phase order.
surface_vtk_files: dict[str, list[Path]] = {
    case_dir.name: sorted(case_dir.glob("*.vtk"))
    for case_dir in vtk_dir.iterdir()
    if case_dir.is_dir()
}
print(len(surface_vtk_files))

# Alignment problems (e.g. NaN translations) are logged to this file
# (basicConfig opens it in append mode by default).
logging.basicConfig(
    filename='aligning_errors.log',
    level=logging.ERROR,
    format='%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s',
    encoding='utf-8'
)

54


In [None]:
def align_surface_rigid(
    moving_surface: pv.PolyData,
    moving_point_cloud: pv.PolyData,
    fixed_point_cloud: pv.PolyData,
    source_name: object = None,
    step: int = 20,
    device: str = 'cuda',
) -> pv.PolyData:
    """Rigidly align ``moving_surface`` to a fixed case via CPD on point clouds.

    The rigid transform (rotation + translation, no scaling) is fitted between
    sub-sampled copies of the moving and fixed point clouds, then applied to
    every point of ``moving_surface``.

    Args:
        moving_surface: Surface mesh to transform. It is copied, not modified.
        moving_point_cloud: Point cloud of the moving case used to fit the transform.
        fixed_point_cloud: Point cloud of the fixed (reference) case.
        source_name: Identifier of the moving input (e.g. its VTK path), used
            only in log messages.
        step: Keep every ``step``-th point of each cloud when fitting.
        device: Torch device the registration runs on.

    Returns:
        A transformed copy of ``moving_surface``.
    """
    source_points = moving_point_cloud.points[::step]
    target_points = fixed_point_cloud.points[::step]
    aligned = moving_surface.copy()
    reg = torchcpd.RigidRegistration(X=target_points, Y=source_points, device=device, scale=False)
    _, (_s, _R, t) = reg.register()
    # Handle NaN in the estimated translation by falling back to zero translation.
    if torch.isnan(t).any():
        logging.error(f"NaN in translation, vtk_file: {source_name}")
        # transform_point_cloud uses the registration's stored translation, not the
        # local `t`, so the fallback must be written back onto `reg` to take effect.
        # NOTE(review): assumes torchcpd mirrors pycpd and keeps it in `reg.t` — confirm.
        reg.t = torch.zeros_like(t)

    moved = reg.transform_point_cloud(torch.tensor(moving_surface.points, device=device))
    aligned.points = moved.cpu().numpy()
    return aligned


def _aligned_output_path(vtk_file: Path) -> Path:
    """Mirror ``vtk_file`` (located under vtk_dir) into aligned_vtk_dir, creating parent dirs."""
    out_file = aligned_vtk_dir / vtk_file.relative_to(vtk_dir)
    out_file.parent.mkdir(parents=True, exist_ok=True)
    return out_file


N_PHASES = 10  # per-case VTK files (phases) to align; every case must have at least this many
SKIP_EXISTING = False  # set True to resume an interrupted run without recomputing saved outputs

fix_case_name = sorted(surface_vtk_files)[0]
fix_case_files = surface_vtk_files[fix_case_name]
# Filter instead of `del`-ing from the shared dict: deleting made this cell
# non-idempotent (a re-run would silently pick and drop the *next* case).
moving_cases = {name: files for name, files in surface_vtk_files.items() if name != fix_case_name}
print(f"{fix_case_name=}")

for phase in tqdm(range(N_PHASES), desc="Processing phases"):
    fix_vtk_file = fix_case_files[phase]
    # The fixed case is the reference frame: its surface is copied through unchanged.
    pv.read(fix_vtk_file).save(_aligned_output_path(fix_vtk_file))
    fix_point_cloud = pv.read(volume_points_cloud_dir / fix_vtk_file.relative_to(vtk_dir))
    for case_files in tqdm(moving_cases.values(), desc="Processing cases", leave=False):
        mov_vtk_file = case_files[phase]
        new_vtk_file = _aligned_output_path(mov_vtk_file)
        if SKIP_EXISTING and new_vtk_file.exists():
            continue
        mov_surface = pv.read(mov_vtk_file)
        mov_point_cloud = pv.read(volume_points_cloud_dir / mov_vtk_file.relative_to(vtk_dir))
        aligned_surface = align_surface_rigid(
            moving_surface=mov_surface,
            moving_point_cloud=mov_point_cloud,
            fixed_point_cloud=fix_point_cloud,
            source_name=mov_vtk_file,
        )
        aligned_surface.save(new_vtk_file)
    

fix_case_name='female_pt106'


Processing phases:   0%|          | 0/10 [00:00<?, ?it/s]

Processing cases: 100%|██████████| 53/53 [03:12<00:00,  3.63s/it]
Processing cases:  38%|███▊      | 20/53 [01:17<02:07,  3.87s/it]]
Processing phases:  10%|█         | 1/10 [04:29<40:28, 269.83s/it]


KeyboardInterrupt: 