In [38]:
import os
import sys 
sys.path.append(os.path.join('..'))


import numpy as np
import torch
import torch.nn.functional as F

from ops.torch_warping import warp_img_torch_3D
from ops.torch_algebra import random_affine_matrix

import nibabel as nib

from pytorch3d.ops import sample_points_from_meshes, cubify
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle

import pyvista as pv
# Start a virtual X framebuffer (pyvista helper, Linux-only) so VTK can render
# without an attached display, e.g. on a headless server; wait=0 skips the
# post-start delay.
pv.start_xvfb(wait=0)
# Render pyvista plotters inside the notebook as embedded HTML widgets.
pv.set_jupyter_backend('html')

import trimesh

from GHD.GHD_cardiac import GHD_Cardiac
from GHD import GHD_config

from data.dataset import ACDC_dataset_Simple
from torch.utils.data import DataLoader, Dataset

from losses import *
from ops.mesh_geometry import *

import data.data_utils as dut
from ops.medical_related import *


In [39]:
# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [40]:
# Locate the ACDC data: resolve '../..' to an absolute path, take its parent
# directory, and descend into Dataset/ACDC.
acdc_workspace = os.path.dirname(os.path.realpath('../..'))
root_path = os.path.join(acdc_workspace, 'Dataset', 'ACDC')

# NOTE(review): `dataset_path` is built but never used -- the dataset below is
# constructed from `root_path` with mode="train". Confirm whether
# ACDC_dataset_Simple appends the split directory itself.
dataset_path = os.path.join(root_path, 'train')

acda_dataset = ACDC_dataset_Simple(
    dataset_path=root_path,
    mode="train",
    process_device=device,
)

# Dataset index ranges per pathology group, as annotated by the original author
# (not verified here):
#   0 - 20   Dilated Cardiomyopathy (DCM)
#   21 - 40  Hypertrophic Cardiomyopathy (HCM)
#   41 - 60  Myocardial Infarction (MINF)
#   61 - 80  Normal (NORM)
# e.g. acda_dataset[75] is a healthy case.

In [41]:
# Canonical template meshes: the LV surface that GHD deforms and a
# bi-ventricular reference shape.
# (Older alternative locations: 'metadata/Standard_LV.obj', 'metadata/Standard_BiV.obj'.)
base_shape_path = '../canonical_shapes/Standard_LV_4055.obj'
bi_ventricle_path = '../canonical_shapes/Standard_BiV.obj'

# Mixing weights for the Laplacian variants (cotangent / distance / standard,
# presumably) that GHD_config combines.
laplacian_mix = {'cotlap': 1.0, 'dislap': 0.1, 'stdlap': 0.1}

cfg = GHD_config(
    base_shape_path=base_shape_path,
    bi_ventricle_path=bi_ventricle_path,
    num_basis=6 ** 2,  # 36 basis functions
    mix_laplacian_tradeoff=laplacian_mix,
    device=device,
    if_nomalize=True,  # (sic) keyword spelling as used by the original call
    if_return_scipy=True,
)

paraheart = GHD_Cardiac(cfg)


num_basis 36
mix_laplacian_tradeoff {'cotlap': 1.0, 'dislap': 0.1, 'stdlap': 0.1}


In [58]:

# Select one case and one cardiac phase, extract labelled point clouds, and
# initialise the GHD rigid pose from a 4-chamber-view frame.
CASE_INDEX = 52   # reported group: MINF
target_F = 'ED'   # 'ED' or 'ES'
# target_F = 'ES'

if target_F not in ('ED', 'ES'):
    raise ValueError(f"target_F must be 'ED' or 'ES', got {target_F!r}")

# Full-dataset loader (not needed for the single-case fit below).
dataloader = DataLoader(acda_dataset, batch_size=1, shuffle=False)

# Load only the selected case. Wrapping it in a Subset keeps the batch dimension
# the DataLoader adds, without loading every earlier case just to skip it.
case_loader = DataLoader(torch.utils.data.Subset(acda_dataset, [CASE_INDEX]),
                         batch_size=1, shuffle=False)
example = next(iter(case_loader))

image_ed_tem = example['img_ed'].to(device)
image_es_tem = example['img_es'].to(device)
label_ed_tem = example['seg_gt_ed'].to(device)
label_es_tem = example['seg_gt_es'].to(device)
affine_tem = example['affine'].to(device)

print(target_F)

if target_F == 'ED':
    image_tem, label_tem = image_ed_tem, label_ed_tem
else:
    image_tem, label_tem = image_es_tem, label_es_tem

# Min-max normalise intensities to [0, 1]. The floor on the range only matters
# for a constant image, which would otherwise divide by zero.
intensity_min, intensity_max = image_tem.min(), image_tem.max()
image_tem = (image_tem - intensity_min) / max((intensity_max - intensity_min).item(), 1e-8)

group = example['group']
print(group)

coordinate_map_tem = dut.get_coord_map_3d_normalized(image_tem.shape[-3:], affine_tem)

# Voxel indices (Z, Y, X) of each label class in the selected segmentation.
seg = label_tem[0, 0]
Z_bg, Y_bg, X_bg = torch.where(seg == acda_dataset.label_value[0])
Z_rv, Y_rv, X_rv = torch.where(seg == acda_dataset.label_value[1])
Z_lv, Y_lv, X_lv = torch.where(seg == acda_dataset.label_value[2])
Z_cav, Y_cav, X_cav = torch.where(seg == acda_dataset.label_value[3])

# Look up the normalised coordinate of every labelled voxel (one xyz row each).
Pt_rv = coordinate_map_tem[0, Z_rv, Y_rv, X_rv]
Pt_lv = coordinate_map_tem[0, Z_lv, Y_lv, X_lv]
Pt_cav = coordinate_map_tem[0, Z_cav, Y_cav, X_cav]
Pt_bg = coordinate_map_tem[0, Z_bg, Y_bg, X_bg]

# Initial pose: a 4-chamber-view frame built from the cavity, LV and RV points,
# using the third column of the image affine as the given axis.
geom_dict = get_4chamberview_frame(Pt_cav, Pt_lv, Pt_rv, given_u2d_axis=affine_tem[0, :3, 2])
inital_affine = geom_dict['target_affine']

points_bi = torch.cat([Pt_rv, Pt_lv], dim=0)
points_lv = Pt_lv
points_outoflv = torch.cat([Pt_rv, Pt_cav, Pt_bg], dim=0)

# Axis-aligned box around the LV points, padded by 0.05 per side. bbox_lv has
# shape (3, 2), holding [min, max] per axis. Keep only the out-of-LV points that
# lie strictly inside it.
bbox_lv = torch.stack([Pt_lv.min(dim=0)[0] - 0.05, Pt_lv.max(dim=0)[0] + 0.05], dim=-1)
inside_bbox = ((points_outoflv > bbox_lv[:, 0]) & (points_outoflv < bbox_lv[:, 1])).all(dim=1)
points_outoflv_in_bbox = points_outoflv[inside_bbox]

paraheart.reset_affine_param()
paraheart.reset_GHD_param()

# Seed the GHD pose: rotation block -> axis-angle for R, last column -> T.
paraheart.R = matrix_to_axis_angle(inital_affine[..., :3, :3].to(paraheart.device)).view(paraheart.R.shape)
paraheart.T = inital_affine[..., :3, 3].to(paraheart.device).view(paraheart.T.shape)

ED
['MINF']


In [59]:
# Randomly subsample the point clouds used for global registration (the result
# reports rot / scale / t). A seeded generator keeps the initial pose
# reproducible under Restart & Run All.
SAMPLE_SEED = 0
sample_num = 2000
rng = np.random.default_rng(SAMPLE_SEED)


def _random_rows(points, k, generator):
    """Return up to `k` distinct rows of `points`, chosen uniformly without replacement.

    `k` is capped at the number of rows, so a small segmentation does not make
    `choice(..., replace=False)` raise.
    """
    idx = generator.choice(points.shape[0], min(k, points.shape[0]), replace=False)
    return points[idx]


sample_lv = _random_rows(points_lv, sample_num, rng)
sample_outoflv = _random_rows(points_outoflv_in_bbox, sample_num, rng)

# Register to the bi-ventricular points, then to the LV samples (the LV step
# presumably refines the bi-ventricular result). The LV step's returned dict is
# the cell output.
# paraheart.global_registration_lv(sample_lv.detach().cpu().numpy())
paraheart.global_registration_biv(_random_rows(points_bi, sample_num, rng).detach().cpu().numpy())

paraheart.global_registration_lv(sample_lv.detach().cpu().numpy())

{'rot': array([[ 0.12123755,  0.99233384,  0.02397924],
        [ 0.99257122, -0.12094781, -0.01319055],
        [-0.01018919,  0.02540029, -0.99962543]]),
 'scale': 1.2717972882660336,
 't': array([ 0.14342066, -0.0009006 , -0.09679929])}

In [60]:
# Visual check of the globally registered LV template against the image.
# Image slices, template mesh and registration samples are all mapped through
# the inverse of paraheart.get_affine_matrix(), i.e. into the frame in which
# the GHD template is defined.
out_ghd_mesh = paraheart.rendering()
affine_tem_np = paraheart.get_affine_matrix().detach().cpu().numpy()
affine_tem_inv = np.linalg.inv(affine_tem_np)
inv_linear, inv_trans = affine_tem_inv[:3, :3], affine_tem_inv[:3, 3]

# Row-vector form of p -> A p + t, applied to every voxel coordinate.
coordinate_map_np = coordinate_map_tem[0].detach().cpu().numpy() @ inv_linear.T + inv_trans

# Highlight the dataset's label value (not a literal) so the painted pixels
# match the label that defined points_lv / sample_lv.
myo_label = acda_dataset.label_value[2]

pl = pv.Plotter(notebook=True)

# Draw every slice of short stacks; thin out long ones.
num_slices = image_tem.shape[-3]
interval = 1
if num_slices > 20:
    interval = 2
if num_slices > 256:
    interval = 10

for i in range(0, num_slices, interval):
    x, y, z = coordinate_map_np[i, ..., 0], coordinate_map_np[i, ..., 1], coordinate_map_np[i, ..., 2]
    grid = pv.StructuredGrid(x, y, z)

    # Compare in torch (as in the label-extraction cell), then .T before
    # flatten, presumably to match StructuredGrid's point ordering (TODO confirm).
    seg_mask = (label_tem[0, 0, i] == myo_label).cpu().numpy().T.flatten().astype(np.float32)
    raw_image = image_tem[0, 0, i].cpu().numpy().T.flatten()

    # Labelled pixels are drawn at value 1 (black under 'gray_r') and more opaque.
    color_opacity = np.where(seg_mask == 0, 0.2, 0.8).astype(np.float32)
    color_gt = raw_image * (1 - seg_mask) + seg_mask

    pl.add_mesh(grid, scalars=color_gt, cmap='gray_r',
                show_scalar_bar=False, opacity=color_opacity, clim=[0, 1])

trimesh_current_lv = trimesh.Trimesh(out_ghd_mesh.verts_packed().detach().cpu().numpy(),
                                     out_ghd_mesh.faces_packed().detach().cpu().numpy())
trimesh_current_lv.apply_transform(affine_tem_inv)
pl.add_mesh(trimesh_current_lv, color='lightblue', opacity=0.8, show_edges=True, show_vertices=False)

# Optional: also overlay the bi-ventricular template.
# trimesh_current_bi = pv.wrap(paraheart.rendering_bi_ventricle())
# pl.add_mesh(trimesh_current_bi, color='blue', opacity=0.2)

# Registration samples: green = LV points, red = out-of-LV points inside the LV box.
sample_lv_np = sample_lv.detach().cpu().numpy() @ inv_linear.T + inv_trans
sample_outoflv_np = sample_outoflv.detach().cpu().numpy() @ inv_linear.T + inv_trans
pl.add_points(sample_lv_np, color='green', point_size=5)
pl.add_points(sample_outoflv_np, color='red', point_size=2)

# Outline of the [-1, 1]^3 cube for scale.
pl.add_mesh(pv.Box(bounds=[-1, 1, -1, 1, -1, 1]).outline(), color='black')

pl.show()

EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

### Graph Harmonic Morphing

In [61]:
# Loss weights for the graph-harmonic morphing. Earlier configuration, kept for reference:
# {'Loss_occupancy':1., 'Loss_Chamfer_P0':0., 'Loss_Chamfer_N1':0., 'Loss_normal_consistency':0.01,
#  'Loss_Laplacian':0.01, 'Loss_equaledge':0.01, 'Loss_rigid':0.01}
loss_weights = {
    'Loss_occupancy': 1.,
    'Loss_normal_consistency': 0.01,
    'Loss_Laplacian': 0.01,
    'Loss_thickness': 0.02,
}

# Fit the template to the LV points, with out-of-LV points inside the LV box as
# the second point set (presumably negatives, balanced by NP_ratio). Rotation,
# scale and translation are optimised as well (if_fit_R / if_fit_s / if_fit_T).
# NOTE: `loss_dict` is rebound to whatever morphing2lvtarget returns.
current_mesh, loss_dict = paraheart.morphing2lvtarget(
    points_lv,
    points_outoflv_in_bbox,
    target_mesh=None,
    loss_dict=loss_weights,
    lr_start=5e-3,
    num_iter=500,
    num_sample=10000,
    NP_ratio=1,
    if_reset=True,
    if_fit_R=True,
    if_fit_s=True,
    if_fit_T=True,
)

Total Loss 0.0551: 100%|██████████| 500/500 [00:37<00:00, 13.50it/s]

fittings done, the final loss is 0.055081





### Evaluation 

In [62]:
# Fit quality: dice_evaluation's result is subtracted from 1, so it is
# presumably a Dice loss; the printed number is the Dice score in percent.
print('Group:', group)
print('Frame:', target_F)
final_dice = 1 - paraheart.dice_evaluation(points_lv, points_outoflv)
print('Final Dice: {:.4f}%'.format(float(final_dice * 100)))


Group: ['MINF']
Frame: ED
Final Dice: 89.5250%


In [63]:
# Final fit overlaid on the image, everything mapped through the inverse of the
# fitted GHD affine. Pink = fitted LV surface, light green = LV label points,
# blue = RV label points.
out_ghd_mesh = paraheart.rendering()

affine_tem_np = paraheart.get_affine_matrix().detach().cpu().numpy()
affine_tem_inv = np.linalg.inv(affine_tem_np)
inv_linear, inv_trans = affine_tem_inv[:3, :3], affine_tem_inv[:3, 3]

# Row-vector form of p -> A p + t, applied to every voxel coordinate.
coordinate_map_np = coordinate_map_tem[0].detach().cpu().numpy() @ inv_linear.T + inv_trans

# Highlight the dataset's label value (not a literal) so the painted pixels
# match the label that defined points_lv.
myo_label = acda_dataset.label_value[2]

pl = pv.Plotter(notebook=True)

# Draw every slice of short stacks; thin out long ones.
num_slices = image_tem.shape[-3]
interval = 1
if num_slices > 20:
    interval = 2
if num_slices > 256:
    interval = 10

for i in range(0, num_slices, interval):
    x, y, z = coordinate_map_np[i, ..., 0], coordinate_map_np[i, ..., 1], coordinate_map_np[i, ..., 2]
    grid = pv.StructuredGrid(x, y, z)

    # Compare in torch (as in the label-extraction cell), then .T before
    # flatten, presumably to match StructuredGrid's point ordering (TODO confirm).
    seg_mask = (label_tem[0, 0, i] == myo_label).cpu().numpy().T.flatten().astype(np.float32)
    raw_image = image_tem[0, 0, i].cpu().numpy().T.flatten()

    # Labelled pixels are drawn at value 1 (black under 'gray_r') and more opaque.
    color_opacity = np.where(seg_mask == 0, 0.2, 0.8).astype(np.float32)
    color_gt = raw_image * (1 - seg_mask) + seg_mask

    pl.add_mesh(grid, scalars=color_gt, cmap='gray_r',
                show_scalar_bar=False, opacity=color_opacity, clim=[0, 1])

points_lv_np = points_lv.detach().cpu().numpy() @ inv_linear.T + inv_trans
pl.add_points(points_lv_np, color='lightgreen', point_size=5, opacity=0.6, render_points_as_spheres=True)

points_rv_np = Pt_rv.detach().cpu().numpy() @ inv_linear.T + inv_trans
pl.add_points(points_rv_np, color='blue', point_size=3, opacity=0.2, render_points_as_spheres=True)

trimesh_current_lv = trimesh.Trimesh(vertices=out_ghd_mesh.verts_packed().detach().cpu().numpy(),
                                     faces=out_ghd_mesh.faces_packed().detach().cpu().numpy())
# Optional export of the fitted mesh in its un-transformed frame:
# trimesh_current_lv.export('current_mesh_ghd.obj')
trimesh_current_lv.apply_transform(affine_tem_inv)

pl.add_mesh(trimesh_current_lv, color='lightpink',
            opacity=0.9, lighting=True, show_edges=True)

pl.show()

EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

In [64]:
# Stand-alone render of the fitted LV surface (semi-transparent, edges shown),
# saved as output/<group>_<frame>.png.
screenshot_dir = 'output'
# Create the folder first; writing the screenshot into a missing directory would error.
os.makedirs(screenshot_dir, exist_ok=True)
screenshot_path = os.path.join(screenshot_dir, f'{group[0]}_{target_F}.png')

pl = pv.Plotter(notebook=True)
pl.add_mesh(trimesh_current_lv, color='white', opacity=0.2, lighting=True, show_edges=True)
pl.camera.azimuth = 0
pl.camera.elevation = -10
pl.show(screenshot=screenshot_path)

EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

In [None]:
# Wall-thickness estimate of the fitted mesh (`current_mesh` from the morphing
# cell). Only `thickness` is used later, to colour the surface.
# NOTE(review): `.forward` is called directly; if MeshThickness is an
# nn.Module, `thicknesser(current_mesh)` would be the idiomatic call.
from ops.mesh_geometry import MeshThickness

thicknesser = MeshThickness()
thickness_outputs = thicknesser.forward(current_mesh)
thickness, _, indx, sign = thickness_outputs


Thickness: tensor(2.0331, device='cuda:0', grad_fn=<MulBackward0>)


In [51]:
# Colour the fitted surface by the thickness estimate (scaled by 100, on a
# fixed [-20, 20] colour range), mapped through affine_tem_inv like the overlays.
current_mesh = paraheart.rendering()

verts_np = current_mesh.verts_packed().detach().cpu().numpy()
faces_np = current_mesh.faces_packed().detach().cpu().numpy()
current_trimesh = trimesh.Trimesh(vertices=verts_np, faces=faces_np)
current_trimesh.apply_transform(affine_tem_inv)

scalars = (thickness * 100).detach().cpu().numpy()

pl = pv.Plotter(notebook=True)
pl.add_mesh(current_trimesh, scalars=scalars, cmap='viridis',
            show_scalar_bar=True, lighting=True, clim=[-20, 20])
pl.add_axes()
pl.show()



EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

In [25]:
# Differentiable enclosed volume via the divergence theorem:
#   V = 1/3 * sum_f (c_f . n_f) * A_f
# with face centroid c_f, unit face normal n_f and face area A_f. This is valid
# for a closed mesh; the sign depends on face orientation. Built from torch ops
# on the mesh, so gradients flow to the vertices.
current_normals = current_mesh.faces_normals_packed()
current_area = current_mesh.faces_areas_packed()
face_verts = current_mesh.verts_packed()[current_mesh.faces_packed()]
current_bary = face_verts.mean(dim=1)

per_face_flux = (current_bary * current_normals).sum(dim=-1) * current_area
volume = per_face_flux.sum() / 3

print(volume.item())



0.13018786907196045


In [26]:
## trimesh volume
# Cavity (endocardial) volume estimate from the fitted wall surface:
# convex-hull volume minus the mesh's own enclosed volume.
# NOTE(review): this counts everything inside the hull but outside the wall as
# cavity, including any concave regions of the outer wall and (presumably) the
# region spanned across the basal opening -- an approximation.
# Unit conversion assumes 1 mesh unit = 100 mm and 1 ml = 1000 mm^3, matching
# the original (100**3)/1000 factor -- TODO confirm against the coordinate normalisation.
MESH_UNIT_MM = 100
MM3_PER_ML = 1000

current_trimesh = trimesh.Trimesh(vertices=current_mesh.verts_packed().detach().cpu().numpy(),
                                  faces=current_mesh.faces_packed().detach().cpu().numpy())
if not current_trimesh.is_watertight:
    print('Warning: mesh is not watertight; trimesh volumes may be unreliable.')

wall_volume = current_trimesh.volume
endo_volume = current_trimesh.convex_hull.volume - wall_volume

print('%.3f' % (endo_volume * MESH_UNIT_MM**3 / MM3_PER_ML), 'ml')

62.914 ml


In [16]:
# Cavity-volume estimate in mesh units: convex-hull volume minus the mesh's enclosed volume.
hull_volume = current_trimesh.convex_hull.volume
hull_volume - current_trimesh.volume

0.044275176432237995

In [27]:
# Ejection-fraction-style ratio (V_a - V_b) / V_a from two cavity volumes in ml,
# hard-coded from earlier runs of the volume cell. Presumably 156.739 is the ED
# and 62.914 the ES cavity volume -- TODO confirm, and recompute from the ED/ES
# fits rather than copying printed numbers.
edv_ml = 156.739
esv_ml = 62.914
ejection_fraction = (edv_ml - esv_ml) / edv_ml
ejection_fraction




0.5986066007821921