In [21]:
import os
import sys 
sys.path.append(os.path.join('..'))


import numpy as np
import torch
import torch.nn.functional as F

from ops.torch_warping import warp_img_torch_3D
from ops.torch_algebra import random_affine_matrix

import nibabel as nib

from pytorch3d.ops import sample_points_from_meshes, cubify
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle

import pyvista as pv
pv.start_xvfb(wait=0)
pv.set_jupyter_backend('html')

import trimesh

import pickle

from GHD.GHD_cardiac import GHD_Cardiac
from GHD import GHD_config

from data_process.dataset_real_scaling import *


In [22]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [31]:
root_path = os.path.dirname(os.path.realpath('..'))

root_path = os.path.join(root_path,'Dataset','ACDC')

dataset_path = os.path.join(root_path, 'train')

acda_dataset = ACDC_dataset_Simple(dataset_path=root_path, mode="train", process_device=device)


mesh_acdc = ACDC_dataset(dataset_path=root_path, mode="train", process_device=device, label_value_list=[[2]])

output = acda_dataset[75] # Healthy case

In [32]:
point_list = point_cloud_extractor(output['seg_gt_ed'],  [0,1,2,3], output['window'], spacing=200, coordinate_order = 'zyx')


pl = pv.Plotter(notebook=True)
color_list = ['lightblue', 'lightsalmon', 'lightgreen']
for point, color in zip(point_list[1:], color_list):
    pl.add_points(point.cpu().numpy(), color=color, point_size=10, opacity=0.2)

img = output['img_ed']
window = output['window']

for i in range(img.shape[-1]):
    if i%2 == 0:
        continue
    meshgrid_x, meshgrid_y, meshgrid_z = meshgrid_from_slices(img, i, window, rescalar = 200.0)
    grid = pv.StructuredGrid(meshgrid_z.cpu().numpy(), meshgrid_y.cpu().numpy(), meshgrid_x.cpu().numpy())
    scalars = img[0,0,:,:,i].cpu().numpy().T # Becareful about the transpose here, the original image is in zyx, but the meshgrid is in xyz
    # grey to rgb

    pl.add_mesh(grid, scalars=scalars.reshape(-1), cmap = 'gray',opacity=0.1, show_scalar_bar=False)


# trimeshtarget = trimesh.Trimesh(vertices=meshtarget.verts_packed().detach().cpu().numpy(), faces=meshtarget.faces_packed().detach().cpu().numpy())

# pl.add_mesh(trimeshtarget, color='lightgreen', opacity=0.5)

pl.show()


EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

In [33]:
root_path = os.path.dirname(os.path.realpath('.'))

base_shape_path = 'canonical_shapes/Standard_LV_2000.obj'
base_shape_path = os.path.join(root_path, base_shape_path)
bi_ventricle_path = 'canonical_shapes/Standard_BiV.obj'
bi_ventricle_path = os.path.join(root_path, bi_ventricle_path)

# base_shape_path = 'metadata/Standard_LV.obj'
# bi_ventricle_path = 'metadata/Standard_BiV.obj'

cfg = GHD_config(base_shape_path=base_shape_path,
            num_basis=6**2, mix_laplacian_tradeoff={'cotlap':1.0, 'dislap':0.1, 'stdlap':0.1},
            device='cuda:0',
            if_nomalize=True, if_return_scipy=True, 
            bi_ventricle_path=bi_ventricle_path)

paraheart = GHD_Cardiac(cfg) # 

GHD config:
base_shape_path /home/yihao/Document/GHDHeart/canonical_shapes/Standard_LV_2000.obj
num_basis 36
device cuda:0
mix_laplacian_tradeoff {'cotlap': 1.0, 'dislap': 0.1, 'stdlap': 0.1}
if_lap_nomalize True
eign_path None
if_nomalize True
if_return_scipy True
bi_ventricle_path /home/yihao/Document/GHDHeart/canonical_shapes/Standard_BiV.obj


BiVentricle Global Registration 

In [34]:
# load initial orientation according to dataset

init_affine_path = os.path.join(root_path, 'canonical_shapes/acdc_init_affine.pkl')

with open(init_affine_path, 'rb') as f:
    initial_orientation = pickle.load(f)

R = initial_orientation[:3,:3]
T = initial_orientation[:3,3]

paraheart.R = matrix_to_axis_angle(torch.from_numpy(R).to(paraheart.device)).view(paraheart.R.shape)
paraheart.T = torch.from_numpy(T).to(paraheart.device).view(paraheart.T.shape)




In [35]:
points_bi = torch.cat(point_list[1:3], dim=0)

mesh_gt_bi_sample = points_bi.detach().cpu().numpy()[np.random.choice(points_bi.shape[0], 500, replace=False)]

paraheart.global_registration_biv(mesh_gt_bi_sample)



points_lv = point_list[2]

mesh_gt_lv_sample = points_lv.detach().cpu().numpy()[np.random.choice(points_lv.shape[0], 2000, replace=False)]

paraheart.global_registration_lv(mesh_gt_lv_sample)


{'rot': array([[-0.19994186,  0.16802241, -0.96529359],
        [ 0.92057801, -0.30512614, -0.24379122],
        [-0.3354987 , -0.93737213, -0.09367025]]),
 'scale': 1.2713718228778272,
 't': array([-0.15720864, -0.12574415,  0.02975088])}

In [36]:
current_mesh = paraheart.rendering()

pl = pv.Plotter(notebook=True)

color_list = ['lightblue', 'lightsalmon', 'lightgreen']
for point, color in zip(point_list[1:], color_list):
    pl.add_points(point.cpu().numpy(), color=color, point_size=10, opacity=0.2)

img = output['img_ed']
window = output['window']

for i in range(img.shape[-1]):
    if i%4 > 0:
        continue
    meshgrid_x, meshgrid_y, meshgrid_z = meshgrid_from_slices(img, i, window, rescalar = 200.0)
    grid = pv.StructuredGrid(meshgrid_z.cpu().numpy(), meshgrid_y.cpu().numpy(), meshgrid_x.cpu().numpy())
    scalars = img[0,0,:,:,i].cpu().numpy().T # Becareful about the transpose here, the original image is in zyx, but the meshgrid is in xyz
    # grey to rgb

    pl.add_mesh(grid, scalars=scalars.reshape(-1), cmap = 'gray',opacity=0.1, show_scalar_bar=False)

current_trimesh = trimesh.Trimesh(vertices=current_mesh.verts_packed().detach().cpu().numpy(), faces=current_mesh.faces_packed().detach().cpu().numpy())
current_trimesh = pv.wrap(current_trimesh)

current_trimesh_bi = paraheart.rendering_bi_ventricle()
current_trimesh_bi = pv.wrap(current_trimesh_bi)

pl.add_mesh(current_trimesh, color='lightblue')
pl.add_mesh(current_trimesh_bi, color='lightgreen')
pl.add_axes()
pl.show()

EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

Graph Harmonic Morphing

In [37]:
points_lv = point_list[2]
bbox_lv = torch.stack([points_lv.min(dim=0)[0], points_lv.max(dim=0)[0]], dim=0).T
rescale = 1.1
bbox_lv_center =  bbox_lv.mean(-1)
bbox_lv = torch.stack([bbox_lv_center-rescale*(bbox_lv_center-bbox_lv[:,0]), bbox_lv_center+rescale*(bbox_lv[:,1]-bbox_lv_center)], dim=-1)

In [38]:
points_outoflv = torch.cat([point_list[0], point_list[1]], dim=0)
points_outoflv_in_bbox = points_outoflv[(points_outoflv[:,0]>bbox_lv[0,0]) & (points_outoflv[:,0]<bbox_lv[0,1]) & (points_outoflv[:,1]>bbox_lv[1,0]) & (points_outoflv[:,1]<bbox_lv[1,1]) & (points_outoflv[:,2]>bbox_lv[2,0]) & (points_outoflv[:,2]<bbox_lv[2,1])]
points_outoflv_in_bbox = torch.cat([points_outoflv_in_bbox, point_list[-1]], dim=0)

In [39]:
mesh_after_globalreg = paraheart.rendering()

In [40]:
# loss_dict = {'Loss_occupancy':1., 'Loss_Chamfer_P0':0., 'Loss_Chamfer_N1':0., 'Loss_normal_consistency':0.01, 'Loss_Laplacian':0.01, 'Loss_equaledge':0.01, 'Loss_rigid':0.01}


loss_dict = {'Loss_occupancy':1., 'Loss_normal_consistency':0.01, 'Loss_Laplacian':0.1, 'Loss_thickness':0.02}


current_mesh, loss_dict = paraheart.morphing2lvtarget(points_lv, points_outoflv_in_bbox, target_mesh=None, 
                                    loss_dict = loss_dict,
                                    lr_start=1e-4, num_iter=1000, num_sample=10000, NP_ratio=1,
                                    if_reset=True, if_fit_R=False, if_fit_s=True, if_fit_T=True)

Total Loss: 0.0000:   0%|          | 0/1000 [00:00<?, ?it/s]

Total Loss: 0.0000: 100%|██████████| 1000/1000 [00:16<00:00, 61.80it/s]

fittings done, the final loss is 0.046567





In [41]:
final_dice = 1 - paraheart.dice_evaluation(points_lv, torch.cat([point_list[0], point_list[1]], dim=0))
print('Final Dice: %.4f'%(final_dice*100)+'%')


Final Dice: 95.4160%


In [42]:
current_mesh = paraheart.rendering()

pl = pv.Plotter(notebook=True)

color_list = ['lightgreen']
for point, color in zip(point_list[2:3], color_list):
    pl.add_points(point.cpu().numpy(), color=color, point_size=5, opacity=0.6, render_points_as_spheres=True)

img = output['img_ed']
window = output['window']

for i in range(img.shape[-1]):
    if i%2 > 0:
        continue
    meshgrid_x, meshgrid_y, meshgrid_z = meshgrid_from_slices(img, i, window, rescalar = 200.0)
    grid = pv.StructuredGrid(meshgrid_z.cpu().numpy(), meshgrid_y.cpu().numpy(), meshgrid_x.cpu().numpy())
    scalars = img[0,0,:,:,i].cpu().numpy().T # Becareful about the transpose here, the original image is in zyx, but the meshgrid is in xyz
    # grey to rgb

    pl.add_mesh(grid, scalars=scalars.reshape(-1), cmap = 'gray',opacity=0.4, show_scalar_bar=False)

current_trimesh = trimesh.Trimesh(vertices=current_mesh.verts_packed().detach().cpu().numpy(), faces=current_mesh.faces_packed().detach().cpu().numpy())
current_trimesh.export('current_mesh_ghd.obj')
current_trimesh = pv.wrap(current_trimesh)

pl.add_mesh(current_trimesh, color='lightpink', opacity=0.9, lighting=True, show_edges=True)

# trimeshtarget = trimesh.Trimesh(vertices=meshtarget.verts_packed().detach().cpu().numpy(), faces=meshtarget.faces_packed().detach().cpu().numpy())

# pl.add_mesh(pv.wrap(trimeshtarget), color='lightgreen', opacity=0.5)

pl.show()

EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

In [43]:
from ops.mesh_geometry import MeshThickness
thicknesser = MeshThickness()
thickness,_,indx,sign = thicknesser.forward(current_mesh)

In [44]:
current_mesh = paraheart.rendering()

pl = pv.Plotter(notebook=True)

current_trimesh = trimesh.Trimesh(vertices=current_mesh.verts_packed().detach().cpu().numpy(), faces=current_mesh.faces_packed().detach().cpu().numpy())
scalars = (sign).detach().cpu().numpy()
pl.add_mesh(current_trimesh, scalars=scalars, cmap='viridis', show_scalar_bar=True, lighting=True, clim=[-0.2, 0.2])
pl.add_axes()
pl.show()



EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

In [45]:
# differentiable volume 

current_normals = current_mesh.faces_normals_packed()
current_area = current_mesh.faces_areas_packed()
current_bary = current_mesh.verts_packed()[current_mesh.faces_packed()].mean(dim=1)

volume = ((current_bary*current_normals).sum(dim=-1)*current_area).sum()/3

print(volume.item())



0.12181388586759567


In [46]:
## trimesh volume

current_trimesh = trimesh.Trimesh(vertices=current_mesh.verts_packed().detach().cpu().numpy(), faces=current_mesh.faces_packed().detach().cpu().numpy())

current_trimesh.volume

# endo volume 

endo_volume = current_trimesh.convex_hull.volume - current_trimesh.volume

print('%.3f'%(endo_volume*(100**3)/1000),'ml')

163.991 ml


In [47]:
# endo volume 

endo_volume = current_trimesh.convex_hull.volume - current_trimesh.volume

print('%.3f'%(endo_volume*(100**3)/1000),'ml')

163.991 ml
