In [None]:
# Rendering setup: start a virtual framebuffer (Xvfb) so PyVista can render
# without a physical display, then select the 'html' Jupyter backend for
# the plotter cells further down.
import pyvista as pv
pv.start_xvfb(wait=0)
pv.set_jupyter_backend('html')

import os
import sys
# Append the parent directory ('../.') to the import search path.
# NOTE(review): presumably this is what lets the project-local packages
# imported below (data_process, ops, GHD, losses) resolve when the notebook
# runs from a sub-folder of the repository -- confirm the working directory.
sys.path.append(os.path.join('..', '.'))

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

from pytorch3d.structures import Meshes
from pytorch3d.io import load_objs_as_meshes, save_obj
from pytorch3d.ops import cubify, cot_laplacian, sample_points_from_meshes, knn_points, knn_gather, norm_laplacian, taubin_smoothing
from pytorch3d.loss import chamfer_distance
from pytorch3d.utils import ico_sphere

from torch_geometric.utils import degree, to_undirected, to_dense_adj, get_laplacian, add_self_loops
from torch_geometric.data import Data
# from torch_geometric.transforms import gdc
from torch_scatter import scatter

import numpy as np

import trimesh


from scipy.sparse.linalg import eigsh
from scipy.sparse import coo_matrix


from data_process.dataset_real_scaling import UKBB_dataset, MMWHS_dataset, ACDC_dataset, CCT48_dataset
from ops.graph_operators import NativeFeaturePropagation, LaplacianSmoothing


from tqdm import tqdm

from probreg import cpd

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle


import warnings
# Silence every warning for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings from the
# libraries imported in this cell; narrow the filter if debugging.
warnings.filterwarnings("ignore")

from GHD import GHD_config, GHDmesh, Normal_iterative_GHDmesh
from GHD.GHD_cardiac import GHD_Cardiac


from data_process.dataset_real_scaling import UKBB_dataset, MMWHS_dataset, ACDC_dataset, CCT48_dataset

from einops import rearrange, einsum, repeat

from pytorch3d.loss import chamfer_distance,mesh_laplacian_smoothing, mesh_normal_consistency, mesh_edge_loss


from losses import *
from ops.mesh_geometry import *

import pickle


In [None]:
# Canonical template meshes used to build the GHD model.
base_shape_path = '../canonical_shapes/Standard_LV_2000.obj'
bi_ventricle_path = '../canonical_shapes/Standard_BiV.obj'

# Alternative template location:
# base_shape_path = 'metadata/Standard_LV.obj'
# bi_ventricle_path = 'metadata/Standard_BiV.obj'

# NOTE(review): the device is pinned to 'cuda:3'; this cell fails on machines
# that do not expose a fourth GPU.
cfg = GHD_config(
    base_shape_path=base_shape_path,
    bi_ventricle_path=bi_ventricle_path,
    num_basis=7 * 7,
    mix_laplacian_tradeoff={'cotlap': 1.0, 'dislap': 0.1, 'stdlap': 0.1},
    device='cuda:3',
    if_nomalize=True,
    if_return_scipy=True,
)

paraheart = GHD_Cardiac(cfg)


# Load the dataset-specific initial orientation (an affine matrix whose
# top-left 3x3 block is read as the rotation and whose last column is read
# as the translation).
# NOTE(review): pickle.load can execute arbitrary code -- only open files
# that come from a trusted source.
init_affine_path = '../canonical_shapes/mmwhs_init_affine.pkl'
with open(init_affine_path, 'rb') as f:
    initial_orientation = pickle.load(f)

affine_rows = initial_orientation[:3]
R = affine_rows[:, :3].astype(np.float32)
T = affine_rows[:, 3].astype(np.float32)

# The model stores its rotation as an axis-angle vector; convert the matrix
# and reshape both pose tensors to the shapes the model already holds.
rotation_axis_angle = matrix_to_axis_angle(torch.from_numpy(R).to(paraheart.device))
paraheart.R = rotation_axis_angle.view(paraheart.R.shape)

translation = torch.from_numpy(T).to(paraheart.device)
paraheart.T = translation.view(paraheart.T.shape)

In [3]:
# The dataset folder is resolved relative to the directory that contains the
# parent of the current working directory: <that dir>/Dataset/MMWHS.
workspace_dir = os.path.dirname(os.path.realpath('..'))
root_path = os.path.join(workspace_dir, 'Dataset', 'MMWHS')

# Two label groups are requested; the extraction cell below reads channel 0
# of 'seg_gt' as the bi-ventricle mask and channel 1 as the LV mask.
mmwhs_bi_lv = MMWHS_dataset(
    dataset_path=root_path,
    output_size=(128, 128, 128),
    label_value_list=[[205., 600.], [205.]],
    if_augment=False,
    eps=0.01,
    process_device=paraheart.device,
)


In [None]:
dataloader_mmwhs = DataLoader(mmwhs_bi_lv, batch_size=1, shuffle=False)

# Only referenced by the (disabled) Laplacian smoothing alternative below.
LaplacianSmoother = LaplacianSmoothing()

# Index of the single dataset sample this notebook works on.
SAMPLE_INDEX = 1
# Enlargement factor applied to the LV bounding box.
BBOX_SCALE = 1.3


def _mask_to_mesh(mask, window_size):
    """Turn a batched binary volume into a smoothed surface mesh.

    The volume is cubified at the 0.5 iso-level, its vertices are rescaled by
    ``window_size / 200`` and the result is Taubin-smoothed (20 iterations).

    NOTE(review): assumes ``mask`` is a float tensor of shape (B, D, H, W) and
    that ``window_size`` broadcasts against the padded vertices -- confirm
    against the dataset. The ``/ 200`` factor is carried over unchanged.
    """
    mesh = cubify(mask, 0.5)
    mesh = mesh.update_padded((mesh.verts_padded() * window_size / 200).float())
    # Alternative: LaplacianSmoother.mesh_smooth(mesh, num_iterations=1)
    return taubin_smoothing(mesh, 0.1, 0.5, num_iter=20)


def _mask_to_points(selected, window_size):
    """Return the selected voxels of a batched boolean volume as a point cloud.

    Voxel indices are normalised to [-1, 1] per axis, the axis order is
    reversed (columns [2, 1, 0]) and the result is scaled by
    ``window_size / 200`` -- the same scaling applied to the mesh vertices.

    The batch index returned by ``torch.where`` is dropped, so this is only
    meaningful for batch size 1 (as configured on the DataLoader above).
    """
    voxel_idx = torch.stack(torch.where(selected)[1:], dim=-1).float()
    grid_extent = torch.tensor(selected.shape[-3:]).float().to(voxel_idx.device) - 1
    normalised = voxel_idx / grid_extent * 2 - 1
    return normalised[:, [2, 1, 0]] * window_size / 200


for i, data in enumerate(dataloader_mmwhs):
    if i < SAMPLE_INDEX:
        continue
    # img = data['img'].to(device)
    seg_gt = data['seg_gt']

    # Channel 0: bi-ventricle mask, channel 1: left-ventricle mask.
    seg_gt_bi = seg_gt[:, 0, ...].to(paraheart.device).float()
    seg_gt_lv = seg_gt[:, 1, ...].to(paraheart.device).float()

    window_size = data['window_size'].to(paraheart.device).float()

    # Ground-truth surfaces.
    mesh_gt_lv = _mask_to_mesh(seg_gt_lv, window_size)
    mesh_gt_bi = _mask_to_mesh(seg_gt_bi, window_size)

    # Enlarged LV bounding box, shape (3, 2) = per-axis [min, max].
    # The box is grown about its own centre. Multiplying the raw min/max
    # corners by the factor would scale about the origin instead, which
    # shifts the box (and can cut off one side) whenever the LV is not
    # centred in the volume.
    bbox_raw = mesh_gt_lv.get_bounding_boxes()[0]
    bbox_center = bbox_raw.mean(dim=-1, keepdim=True)
    bbox_lv = bbox_center + (bbox_raw - bbox_center) * BBOX_SCALE

    # Point clouds from the masks: bi-ventricle, LV, and everything outside the LV.
    points_bi = _mask_to_points(seg_gt_bi > 0.5, window_size)
    points_lv = _mask_to_points(seg_gt_lv > 0.5, window_size)
    points_outoflv = _mask_to_points(seg_gt_lv < 0.5, window_size)

    # Keep only the outside points that fall strictly inside the enlarged box.
    inside_bbox = ((points_outoflv > bbox_lv[:, 0]) & (points_outoflv < bbox_lv[:, 1])).all(dim=-1)
    points_outoflv_in_bbox = points_outoflv[inside_bbox]

    # Only the selected sample is processed.
    break

In [None]:
sample_num = 2000

# Seeded generator so the subsampled point sets (and therefore the global
# registration result) are reproducible across kernel restarts.
SAMPLING_SEED = 0
rng = np.random.default_rng(SAMPLING_SEED)

# Never request more points than the mask provides: sampling without
# replacement raises ValueError when the request exceeds the population.
bi_count = points_bi.shape[0]
bi_idx = rng.choice(bi_count, size=min(sample_num, bi_count), replace=False)
mesh_gt_bi_sample = points_bi.detach().cpu().numpy()[bi_idx]
paraheart.global_registration_biv(mesh_gt_bi_sample)


lv_count = points_lv.shape[0]
lv_idx = rng.choice(lv_count, size=min(sample_num, lv_count), replace=False)
# Index with a tensor living on the same device as the points.
sample_lv = points_lv[torch.as_tensor(lv_idx, device=points_lv.device)]
paraheart.global_registration_lv(sample_lv.detach().cpu().numpy())

In [None]:
# Overlay the ground-truth LV surface, the current GHD surface and the LV
# voxel point cloud to inspect the result of the global registration.
pl = pv.Plotter(notebook=True)

# Ground-truth LV surface (green).
gt_lv_verts = mesh_gt_lv.verts_packed().detach().cpu().numpy()
gt_lv_faces = mesh_gt_lv.faces_packed().detach().cpu().numpy()
trimesh_gt_lv = pv.wrap(trimesh.Trimesh(gt_lv_verts, gt_lv_faces))
pl.add_mesh(trimesh_gt_lv, color='lightgreen', opacity=0.5)

# Optional overlay: ground-truth bi-ventricle surface.
# trimesh_gt_bi = trimesh.Trimesh(mesh_gt_bi.verts_packed().detach().cpu().numpy(), mesh_gt_bi.faces_packed().detach().cpu().numpy())
# trimesh_gt_bi = pv.wrap(trimesh_gt_bi)
# pl.add_mesh(trimesh_gt_bi, color='green', opacity=0.2)

out_ghd_mesh = paraheart.rendering()

# Optional overlay: current bi-ventricle template.
# trimesh_current_bi = paraheart.rendering_bi_ventricle()
# trimesh_current_bi = pv.wrap(trimesh_current_bi)
# pl.add_mesh(trimesh_current_bi, color='blue', opacity=0.2)

# Current GHD LV surface (blue).
ghd_lv_verts = out_ghd_mesh.verts_packed().detach().cpu().numpy()
ghd_lv_faces = out_ghd_mesh.faces_packed().detach().cpu().numpy()
trimesh_current_lv = pv.wrap(trimesh.Trimesh(ghd_lv_verts, ghd_lv_faces))
pl.add_mesh(trimesh_current_lv, color='lightblue', opacity=0.5)

# Target voxels (yellow).
# pl.add_points(points_bi.detach().cpu().numpy(), color='red', point_size=5)
pl.add_points(points_lv.detach().cpu().numpy(), color='yellow', point_size=5)
pl.show()


In [None]:

# Optional: subsample the outside points before fitting.
# sample_outoflv = points_outoflv_in_bbox[np.random.choice(points_outoflv_in_bbox.shape[0], sample_num*5, replace=False)]

# Relative weight of each loss term passed to the fit.
lv_loss_weights = {
    'Loss_occupancy': 1,
    'Loss_normal_consistency': 0.01,
    'Loss_Laplacian': 0.02,
    'Loss_equaledge': 0.01,
    'Loss_rigid': 0.1,
}

# Fit the model to the LV target: inside points, outside points restricted to
# the LV bounding box, and the ground-truth mesh. Rotation fitting is switched
# off; scale and translation fitting are switched on.
convergence, Loss_dict_list = paraheart.morphing2lvtarget(
    points_lv,
    points_outoflv_in_bbox,
    target_mesh=mesh_gt_lv,
    loss_dict=lv_loss_weights,
    lr_start=1e-4,
    num_iter=2000,
    if_reset=True,
    if_fit_R=False,
    if_fit_s=True,
    if_fit_T=True,
    record_convergence=True,
)

In [None]:
# Show the morphed GHD surface against the LV voxel point cloud.
pl = pv.Plotter(notebook=True)

out_ghd_mesh = paraheart.rendering()

fitted_verts = out_ghd_mesh.verts_packed().detach().cpu().numpy()
fitted_faces = out_ghd_mesh.faces_packed().detach().cpu().numpy()
trimesh_current_lv = pv.wrap(trimesh.Trimesh(fitted_verts, fitted_faces))

pl.add_mesh(trimesh_current_lv, color='lightblue', opacity=0.5)
pl.add_points(points_lv.detach().cpu().numpy(), color='yellow', point_size=5)
pl.show()