# Set-up and some packages for manipulation

In [None]:
import unittest
import random
import copy as cp

import nrrd
import numpy as np
import astra
import scipy.spatial
from scipy.spatial.distance import cdist
import torchio

import torch
import torchvision.transforms as torch_transform

from pykeops.torch import Vi, Vj
from pykeops.torch import LazyTensor

from skimage import measure
from skimage.util import random_noise
import point_cloud_utils as pcu

from scipy.spatial.transform import Rotation as R
import mcubes
from simpleicp import PointCloud, SimpleICP

from src.algs.arm import lv_indicator
from src.tools.cmf.cmf import CMF_3D
from src.tools.recon.projector import forward_projector, backward_projector
from src.tools.manip.manip import normalize_volume

# data fetching and handling
from data.check_database import load_remote_data
from data.fetch_data import fetch_data
from src.tools.data.loadvolumes import LoadVolumes

# Data containers and set-up for loading the data from a remote HTTP server

In [None]:
# Placeholders; all but lv_volume are assigned further down in this cell.
lv_model_volume, lv_model_frames = None, None
lv_motion_frames, lv_volume = None, None

resolution = 64

# empty cubic grid plus the parameter packs handed to lv_indicator
volume = np.zeros((resolution, resolution, resolution))
params = {'a': 1, 'c': 2, 'sigma': -1}
transform_params = [np.eye(3, 3), [16, 16, 0], 1.5]

# forward projector in 'basic' reconstruction mode
recon_mode = 'basic'
fprojector = forward_projector(recon_mode)

# model volume, its forward projection and a zero array of the same shape
lv_model_volume = lv_indicator(volume, params, transform_params)
lv_model_frames = fprojector(lv_model_volume)
lv_motion_frames = np.zeros(lv_model_frames.shape)

# Individual patient fetching from remote server

In [None]:
dicom_loader = LoadVolumes()

# remote access is configured in data/remote.yml
data_loaded = False
url, datasets = load_remote_data()

# entry 10 of the 'turkey_par/' listing under 'raw/' and its download URL
dicom_name = datasets['raw/']['turkey_par/'][10]
data_url = ''.join((url, '/raw/', 'turkey_par/', dicom_name))

# download the file and hand it to the DICOM loader
data = fetch_data(data_url)
frames, data_loaded = dicom_loader.LoadSinglePatient(data)

# normalize the frame values, then shift every value by one
normalize_volume(frames)
frames = frames + 1

assert data_loaded

# Fetching a series of patients from the remote server

In [None]:
dicom_loader = LoadVolumes()

# remote access is configured in data/remote.yml
data_loaded = False
url, datasets = load_remote_data()

# collect the link targets of both directory listings
from bs4 import BeautifulSoup
import requests

label_root = url + '/recon/' + 'spie_2024/' + 'misc/' + 'label/'
data_root = url + '/recon/' + 'spie_2024/' + 'misc/' + 'data/'

page = requests.get(label_root)
soup = BeautifulSoup(page.content, 'html.parser')
label_names = [label_ref.get('href') for label_ref in soup.find_all('a')]

page = requests.get(data_root)
soup = BeautifulSoup(page.content, 'html.parser')
data_names = [label_ref.get('href') for label_ref in soup.find_all('a')]

subjects = []
subjects_data = []

# data and label files are paired purely by their position in the listings
for index, dicom_name in enumerate(data_names):
    label_name = label_names[index]
    data_url = data_root + dicom_name
    label_url = label_root + label_name

    # download both files
    data = fetch_data(data_url)
    lab = fetch_data(label_url)

    # decode the volume and the NRRD label map
    volume, data_loaded = dicom_loader.LoadSinglePatient(data)
    header = nrrd.read_header(lab)
    labels = nrrd.read_data(header, lab)

    # the label export is not consistent: count voxels labelled 1 and 2 on the
    # axis-reversed label map and keep the less frequent one as the foreground
    labels_reversed = np.transpose(labels, [2, 1, 0])
    prob_val_1 = np.sum(np.where(labels_reversed == 1, 1, 0))
    prob_val_2 = np.sum(np.where(labels_reversed == 2, 1, 0))

    foreground_value = 2 if prob_val_1 > prob_val_2 else 1
    labels = np.where(labels_reversed == foreground_value, 1, 0)

    subject = dict(spect=volume, left_ventricle=labels)
    subjects.append(subject)

    age, gender, weight, height = dicom_loader.CalculatePatientStatistics()
    subject_data = dict(age=age, gender=gender, weight=weight, height=height)
    subjects_data.append(subject_data)

    print("Volume shape: ", volume.shape, "Labels shape:", labels.shape)

    # normalizing the volume values
    normalize_volume(volume)

assert data_loaded

In [None]:
print(len(subjects))

In [None]:
def compute_dataset_means(subjects_data, indices=None):
    """Collect the valid demographic values of the selected subjects.

    Despite its name the function returns the cleaned value arrays, not their
    means; the callers apply ``np.mean`` themselves.

    Args:
        subjects_data: sequence of dicts with the keys 'age', 'gender',
            'weight' and 'height'.
        indices: iterable of positions into ``subjects_data`` (negative
            positions are allowed). ``None`` selects every subject.

    Returns:
        (num_female, num_male, ages, heights, weights): the two counts are
        ints, the other three are 1-D float arrays that hold only the
        strictly positive entries.
    """
    # the default used to be iterated directly, which raised a TypeError
    if indices is None:
        indices = range(len(subjects_data))

    ages, weights, heights = [], [], []
    num_male = 0
    num_female = 0

    # cleaning data: non-positive values are treated as missing
    for i in indices:
        record = subjects_data[i]

        if record['age'] > 0:
            ages.append(record['age'])
        if record['weight'] > 0:
            weights.append(record['weight'])
        if record['height'] > 0:
            heights.append(record['height'])

        # any gender code other than 'F' / 'M' is counted in neither group
        if record['gender'] == 'F':
            num_female += 1
        elif record['gender'] == 'M':
            num_male += 1

    return (num_female, num_male,
            np.asarray(ages, dtype=float),
            np.asarray(heights, dtype=float),
            np.asarray(weights, dtype=float))

In [None]:
num_female, num_male, ages, heights, weights = compute_dataset_means(subjects_data)

# report the gender counts and the mean of every cleaned value array
print("Number of females: ", num_female, "Number of males: ", num_male)
for caption, values in (("Average age: ", ages),
                        ("Average height: ", heights),
                        ("Average weight: ", weights)):
    print(caption, np.mean(values))


In [None]:
dicom_loader = LoadVolumes()

# remote access is configured in data/remote.yml
data_loaded = False
url, datasets = load_remote_data()

# collect the link targets of both directory listings
from bs4 import BeautifulSoup
import requests

label_root = url + '/recon/' + 'spie_2024/' + 'bela/' + 'label/'
data_root = url + '/recon/' + 'spie_2024/' + 'bela/' + 'data/'

page = requests.get(label_root)
soup = BeautifulSoup(page.content, 'html.parser')
label_names = [label_ref.get('href') for label_ref in soup.find_all('a')]

page = requests.get(data_root)
soup = BeautifulSoup(page.content, 'html.parser')
data_names = [label_ref.get('href') for label_ref in soup.find_all('a')]

subjects_bela = []
subjects_bela_data = []

# data and label files are paired purely by their position in the listings
for index, dicom_name in enumerate(data_names):
    label_name = label_names[index]
    data_url = data_root + dicom_name
    label_url = label_root + label_name

    # download both files
    data = fetch_data(data_url)
    lab = fetch_data(label_url)

    # decode the volume and the NRRD label map
    volume, data_loaded = dicom_loader.LoadSinglePatient(data)
    header = nrrd.read_header(lab)
    labels = nrrd.read_data(header, lab)

    # the label export is not consistent: count voxels labelled 1 and 2 on the
    # axis-reversed label map and keep the less frequent one as the foreground
    labels_reversed = np.transpose(labels, [2, 1, 0])
    prob_val_1 = np.sum(np.where(labels_reversed == 1, 1, 0))
    prob_val_2 = np.sum(np.where(labels_reversed == 2, 1, 0))

    foreground_value = 2 if prob_val_1 > prob_val_2 else 1
    labels = np.where(labels_reversed == foreground_value, 1, 0)

    subject = dict(spect=volume, left_ventricle=labels)
    subjects_bela.append(subject)

    age, gender, weight, height = dicom_loader.CalculatePatientStatistics()
    subject_data = dict(age=age, gender=gender, weight=weight, height=height)
    subjects_bela_data.append(subject_data)

    print("Volume shape: ", volume.shape, "Labels shape:", labels.shape)

    # normalizing the volume values
    normalize_volume(volume)

assert data_loaded

In [None]:
print(len(subjects_bela))
print(len(subjects_bela_data))

# index sets per acquisition type:
# MPH : list(range(15, 25)), Par: [-5, -6, -7, -7], cardiod list(range(4, 15)), cardioc [0, 1, 2, 3, 4, -9, -4, -3, -2, -1]
selected_indices = [0, 1, 2, 3, 4, -9, -4, -3, -2, -1]
num_female, num_male, ages, heights, weights = compute_dataset_means(subjects_bela_data, selected_indices)

print("Number of females: ", num_female, "Number of males: ", num_male)
for caption, values in (("Average age: ", ages),
                        ("Average height: ", heights),
                        ("Average weight: ", weights)):
    print(caption, np.mean(values))

# Generating shape priors with the proposed cardiac model; their size matches the frames that have already been fetched

In [None]:
# Unpack the three dimensions of the fetched frame stack.
# NOTE(review): `height` is also used for the patient height in the fetch
# cells, so this assignment overwrites that module-level name.
num_frames, width, height = frames.shape

# Back-project the frames into a volume with the default backward projector.
bprojectpor = backward_projector()
lv_volume = bprojectpor(frames)

In [None]:
# from mpl_toolkits import mplot3d
# import matplotlib.pyplot as plt
# %matplotlib notebook
# 
# # turkey_par, 2 (index: 20, 20:40, 40:60), 4(index: 32, 20:40, 30:50), 7(index: 25, 25:45, 40:60), 10(index: 32, 20:40, 30:50)
# index = 32
# 
# fig = plt.figure()
# fig.set_size_inches((1, 1))
# ax = plt.Axes(fig, [0., 0., 1., 1.])
# ax.set_axis_off()
# fig.add_axes(ax)
#     
# plt.imshow(lv_volume[index, 20:40, 30:50], aspect='equal')
# plt.savefig("/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/images/poster/" + "patient_10" + ".png", bbox_inches='tight', pad_inches=0)

In [None]:

# NOTE(review): this replaces the back-projected lv_volume from the previous
# cell with uniform random noise of fixed size 64^3 — presumably a debugging
# stand-in; confirm before using real patient data.
lv_volume = np.random.rand(64, 64, 64)
normalize_volume(lv_volume)

# one prior volume per sample, each with the shape of lv_volume
num_prior = 9
shape_priors = np.zeros([num_prior, *lv_volume.shape])

# random model parameters, one draw per prior
wall_thickness = np.random.uniform(0.3, 1.0, num_prior)
rot_angles = np.random.uniform(0, 2 * np.pi, num_prior)
curvature = np.random.uniform(1.5, 3, num_prior)
# the bounds are passed as (low=-0.5, high=-1), i.e. in reversed order
sigmas = np.random.uniform(-0.5, -1, num_prior)

for i in range(num_prior):
    volume = np.zeros([*lv_volume.shape])
    params = dict(a=wall_thickness[i], c=curvature[i], sigma=sigmas[i])
    # NOTE(review): rot_mx is built but never used; transform_params below
    # always passes the identity matrix — confirm whether the random rotation
    # was meant to be applied.
    rot_mx = R.from_quat([0, 0, np.sin(rot_angles[i]), np.cos(rot_angles[i])])

    transform_params = [np.eye(3, 3), [16, 16, 0], 1.5]
    shape_priors[i] = lv_indicator(volume, params, transform_params, a_plot=False)

    # the projector is re-created on every iteration with the same mode
    recon_mode = 'basic'
    fprojector = forward_projector(recon_mode)

    # NOTE(review): this overwrites the module-level `frames`; after the loop
    # it holds the projection of the last prior only.
    frames = fprojector(shape_priors[i])
    
# lv_volume = shape_priors[3]
    
# lv_volume = shape_priors[3]

# Implementation of the functions under test (their packages are imported at the top)

In [None]:
def GaussKernel(sigma):
    """Build the symbolic KeOps Gaussian kernel reduction of width ``sigma``.

    Three symbolic variables of dimension 2 are declared: an i-indexed one
    (slot 0) and two j-indexed ones (slots 1 and 2). The returned reduction
    sums ``(0.3989 / sigma) * exp(-gamma * |x - y|^2 / (2 * 64 * 64)) * b``
    with ``gamma = 1 / (2 * sigma^2)`` along axis 1.

    NOTE(review): the cell that follows defines another ``GaussKernel`` with a
    different signature, which replaces this one once it is executed.
    """
    inv_two_sigma_sq = 1 / (2 * sigma * sigma)
    points_i, points_j, weights_j = Vi(0, 2), Vj(1, 2), Vj(2, 2)
    scaled_sq_dist = points_i.sqdist(points_j) / (2 * 64 * 64)
    gauss = (-scaled_sq_dist * inv_two_sigma_sq).exp()
    weighted = (0.3989 / sigma) * gauss * weights_j
    return weighted.sum_reduction(axis=1)

In [None]:
def GaussKernel(x, y, sigma):
    """Sum of element-wise Gaussian responses between ``x`` and ``y``.

    Every element difference ``d`` contributes
    ``(0.3989 / sigma) * exp(-gamma * d^2 / (2 * 64 * 64))`` with
    ``gamma = 1 / (2 * sigma^2)``.

    Args:
        x, y: torch tensors. When either has more than three dimensions the
            leading axis is treated as a sample axis and all pairs are
            evaluated; otherwise the two tensors are compared directly.
        sigma: kernel width (non-zero scalar).

    Returns:
        A ``(x.shape[0], y.shape[0])`` matrix in the pairwise case, otherwise
        a 0-dim tensor. The result keeps the dtype and device of the inputs
        (the pairwise case previously always went through a float32 CPU
        buffer).
    """
    gamma = 1 / (2 * sigma * sigma)
    scale = 0.3989 / sigma

    if len(x.shape) > 3 or len(y.shape) > 3:
        if x.dim() == y.dim():
            # broadcasting builds all pairwise differences in one shot
            diff = x.unsqueeze(1) - y.unsqueeze(0)
        else:
            # mixed ranks: keep the per-pair broadcasting of x[i] - y[j]
            diff = torch.stack([
                torch.stack([x[i] - y[j] for j in range(y.shape[0])])
                for i in range(x.shape[0])
            ])

        D2 = torch.abs(diff) ** 2 / (2 * 64 * 64)
        K = (-D2 * gamma).exp()
        # reduce over every per-sample axis, not just three hard-coded ones
        return torch.sum(scale * K, dim=tuple(range(2, K.dim())))

    D2 = torch.abs(x - y) ** 2 / (2 * 64 * 64)
    K = (-D2 * gamma).exp()
    return torch.sum(scale * K)

In [None]:
# orthogonal complement calculation based on https://github.com/statsmodels/statsmodels/issues/3039 
def orthogonal_complement(x, normalize=False, threshold=1e-15):
    """Compute the orthogonal complement of the column space of a matrix.

    This works along axis zero, i.e. rank == column rank, or number of
    rows > column rank; otherwise the orthogonal complement is empty.

    Args:
        x: 2-D torch tensor of shape (r, c).
        normalize: when True, right-multiply the complement by the inverse of
            its top square block so that this block becomes the identity.
        threshold: singular values above this value count towards the rank.

    Returns:
        Tensor of shape (r, r - rank) whose columns are orthogonal to the
        columns of ``x``.

    TODO possibly: use normalize='top' or 'bottom'
    """
    r, c = x.shape
    if r < c:
        import warnings
        warnings.warn('fewer rows than columns', UserWarning)

    # torch.linalg.svd returns (U, S, Vh); S is a 1-D vector sorted in
    # descending order, so the rank is the number of entries above threshold
    U, S, _ = torch.linalg.svd(x)
    rank = int(torch.count_nonzero(S > threshold))

    # the left singular vectors past the rank span the complement
    oc = U[:, rank:]

    if normalize:
        k_oc = oc.shape[1]
        # matrix product; Tensor.dot only accepts 1-D operands
        oc = oc @ torch.linalg.inv(oc[:k_oc, :])
    return oc

In [None]:
# two independent random point sets with 1000 points in 3-D
x1 = np.random.rand(1000, 3)
x2 = np.random.rand(1000, 3)

# 3x3 cross-product matrix of the two point sets
M = x1.T @ x2

# NOTE: np.linalg.svd returns the right factor already transposed, so the
# variable named `V` holds V^T (M == U @ diag(S) @ V)
U, S, V = np.linalg.svd(M)

In [None]:
# Orthogonal Procrustes: the orthogonal matrix minimising ||x1 - x2 @ proj||
# for M = x1.T @ x2 = U diag(S) Vh is Vh.T @ U.T. np.linalg.svd returns Vh in
# the variable named `V`, so it has to be transposed back here.
proj = V.T @ U.T

In [None]:
proj.shape

In [None]:
proj

In [None]:
# Frobenius-norm residual after mapping x2 with proj, and without any mapping
proj_diff = np.linalg.norm(x1 - x2 @ proj)
diff = np.linalg.norm(x1 - x2)

print(proj_diff)
print(diff)

In [None]:
from scipy.spatial import procrustes
# scipy's procrustes returns the two standardised point sets and the scalar
# disparity; NOTE(review): the disparity overwrites the matrix `M` from above
nx1, nx2,  M = procrustes(x1, x2)

# Create the shape distribution here

In [None]:
shape_priors.shape

In [None]:
num_samples = shape_priors.shape[0]

# number of surface points drawn per prior
num_verts = 100
rand_positions = torch.zeros((num_samples, num_verts, 3))

for i, prior_volume in enumerate(shape_priors):
    # iso-surface at level 0, then uniformly random points on that mesh
    verts, faces, normals, values = measure.marching_cubes(prior_volume, 0)
    fid, bc = pcu.sample_mesh_random(verts, faces, num_verts)
    surface_points = pcu.interpolate_barycentric_coords(faces, fid, bc, verts)
    rand_positions[i] = torch.from_numpy(surface_points)

In [None]:
rand_positions.shape

In [None]:
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
%matplotlib notebook

# surface points of the prior with index 3
lv_surf = rand_positions[3]

# 3-D scatter plot of the sampled points
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter(lv_surf[:, 0], lv_surf[:, 1], lv_surf[:, 2])

In [None]:
z_i = rand_positions  # generated random positions on the lv surface
z_i_a = torch.zeros(rand_positions.shape)  # aligned shapes

# per-sample centroid, size (norm of the centred points) and size-normalised shape
z_i_centroid = torch.mean(z_i, dim=1)
z_i_size = torch.linalg.norm(z_i - z_i_centroid[:, None, :], dim=(1, 2))
z_i_s = z_i / z_i_size[:, None, None]
# NOTE(review): the shapes are scaled but not translated to their centroid
# before the rotation fit — confirm whether centring is intended here.

# Algorithm 3.2 in [1]: start from the first shape as the mean estimate, rotate
# every shape onto the current estimate, re-estimate the mean and repeat until
# the mean stops moving.
mean_shape = z_i_s[0].clone()
prev_mean_shape = mean_shape
iteration = 0
while iteration < 10:
    # include index 0 so that no all-zero row of z_i_a enters the mean
    for i in range(z_i.shape[0]):
        # orthogonal Procrustes of the current shape onto the current mean
        M = mean_shape.t() @ z_i_s[i]
        # torch.linalg.svd returns (U, D, Vh); the optimal right-multiplied
        # rotation is Vh^T U^T (no reflection handling, as before)
        U, D, Vh = torch.linalg.svd(M)
        proj_rot_ref = Vh.t() @ U.t()

        z_i_a[i] = z_i_s[i] @ proj_rot_ref

    # keep the estimate that was aligned to, then refresh the mean
    prev_mean_shape = mean_shape
    mean_shape = torch.mean(z_i_a, dim=0)
    iteration += 1  # it is advised in [1] that only two iterations suffice

    if torch.linalg.norm(prev_mean_shape - mean_shape) < 1e-2:
        break

In [None]:
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
%matplotlib notebook

# which aligned shape to display
index = 2

# 3-D scatter plot of the selected aligned shape
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter(z_i_a[index, :, 0], z_i_a[index, :, 1], z_i_a[index, :, 2])

# Gradient/differential of the PCA in kernel space, orthogonal complement calculation as well

In [None]:
# first two (unaligned) point sets
X1 = z_i[0]
X2 = z_i[1]
print(X1.shape)

# SVD of the cross-product matrix; torch returns the right factor transposed
U, D, V_t = torch.linalg.svd(X1.t() @ X2)
V = V_t.t()
print(" U: ", U.shape, " D: ", torch.diag(D).shape, " V_t: ", V_t.shape)
# rotation and scale estimate for fitting X1 onto X2
R_cap = U @ V_t
Beta_cap = torch.trace(torch.diag(D)) / torch.linalg.norm(X1) ** 2

# residual of X2 against the scaled and rotated X1
print(torch.linalg.norm(X2 - Beta_cap * X1 @ R_cap))

# for comparison: distance between the first two shapes of z_i_a
X1_a = z_i_a[0]
X2_a = z_i_a[1]
print(torch.linalg.norm(X2_a - X1_a))

# Testing the dimensions of the derivatives. Approach 1: using the SVD

In [None]:
# NOTE(review): if U and V are the square orthogonal SVD factors from the
# previous cell, (I - U U^T) and (I - V V^T) are numerically zero, so dU, dV
# and dR would all be ~0 — confirm that this is the intended derivative.
dU = (torch.eye(U.shape[0]) - U @ U.t()) @ V
dV = (torch.eye(V.shape[0]) - V @ V.t()) @ U
print(dU.shape)
print(dV.shape)

# product-rule style combination of the two terms
dR = dU @ V_t + U @ dV
print(dR)

# Approach 2 for the derivative: using the Rodrigues formula, which results in a rank-3 tensor

In [None]:
# will need optimization on this tridiag tensor
# ssc(v): 3x3 skew-symmetric cross-product matrix of the 3-vector v
ssc = lambda v: torch.tensor([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
# NOTE(review): `** 2` on a tensor squares element-wise; the Rodrigues formula
# uses the matrix square of the skew matrix — confirm which one is intended.
# RU is defined here but not called in this cell.
RU = lambda A, B: torch.eye(3) + ssc(torch.cross(A, B)) + ssc(torch.cross(A, B))**2 * (1-torch.dot(A,B)) / torch.norm(torch.cross(A,B)) ** 2

# derivative respect to z_c first
# NOTE(review): the trailing `* A` broadcasts the 3-vector A over the rows of
# the 3x3 term, and ssc(A) is added twice — verify against the derivation.
dRU = lambda A, B: ssc(A) + 2 * ssc(torch.cross(A, B)) + ssc(A) + ssc(torch.cross(A, B)) ** 2 * (1 / (1 + torch.dot(A, B) ** 2)) * A 

N = z_i.shape[1]  # number of elements
M_til = torch.zeros([3 * N, 3 * N, 3 * N])

# NOTE(review): `id` shadows the Python builtin of the same name
id = torch.eye(3)
for j in range(0, 3 * N, 1):
    e_j = id[j % 3]
    # i advances in steps of 3 up to N, so only 3x3 diagonal blocks within the
    # first N rows/columns of each slice are written; point i of shape 0 is
    # paired with point i of shape 5
    for i in range(0, N, 3):
        M_til[i:i+3, i:i+3, j] = dRU(z_i[0, i] * e_j, z_i[5, i])

print(M_til.shape)
# from mpl_toolkits import mplot3d
# import matplotlib.pyplot as plt
# %matplotlib notebook
# plt.imshow(cross)
# from mpl_toolkits import mplot3d
# import matplotlib.pyplot as plt
# %matplotlib notebook
# plt.imshow(cross)

In [None]:
def zero_volume_boundary(a_volume, a_width):
    """Zero a border of ``a_width`` voxels on every face of a 3-D volume.

    The volume is modified in place and nothing is returned.

    Args:
        a_volume: 3-D array or tensor supporting slice assignment.
        a_width: border thickness in voxels; values <= 0 leave the volume
            untouched.
    """
    # a width of 0 would turn `-a_width:` into `0:` and wipe the whole volume
    if a_width <= 0:
        return

    a_volume[:a_width, :, :] = 0
    a_volume[-a_width:, :, :] = 0
    a_volume[:, :a_width, :] = 0
    a_volume[:, -a_width:, :] = 0
    a_volume[:, :, :a_width] = 0
    a_volume[:, :, -a_width:] = 0

In [None]:
def flip_vals(A,val1,val2):
    """Return a copy of ``A`` in which ``val1`` and ``val2`` are exchanged.

    Entries equal to ``val1`` are shifted up by ``val2 - val1`` and entries
    equal to ``val2`` are shifted down by the same amount; every other entry
    is left as it is.
    """
    offset = val2 - val1

    # element-wise masks of the two values that trade places
    is_first = (A == val1)
    is_second = (A == val2)

    return A + offset * is_first - offset * is_second

In [None]:
def nonlinear_shape_prior(shape_priors, kernel, sigma, centering_point):
    """
    Nonlinear statistics shape prior based on kernel density estimation in the feature space
        [1] Shape statistics in kernel space for variational image segmentation - Daniel Cremers, Timo Kohlberger,
                                                                                  Christoph Schnoerr
        [2] Active Shape Models - Their Training and Application - T. F. Cootes, C. J. Taylor, D. H. Cooper, J. Graham

    Args:
        shape_priors: stack of 3-D prior volumes, indexed along axis 0
        kernel: callable ``kernel(a, b, sigma)`` evaluated on pairs of vertex tensors
        sigma: kernel parameter, forwarded to ``kernel`` and returned unchanged
        centering_point: point to which the mean of every scaled vertex set is moved

    Returns:
        Tuple of twelve items, in this order:
        z_i (list of scaled and translated vertex tensors), inverse of the
        regularised covariance, eigenvalues L (descending), eigenvectors V,
        sigma_ort, sigma, first_cplx (index of the first eigenvalue <= 1e-6,
        or -1 when there is none), min_shape_face_count, mean_shape,
        mean_shape_face, K.sum() and the kernel matrix K.
    """
    m = shape_priors.shape[0]
    
    # E is the m x m averaging matrix used for centring the kernel matrix
    E = (1 / m) * torch.ones([m, m], dtype=torch.float64)
    K = torch.zeros([m, m], dtype=torch.float64)
    
    # only `depth` is used below; height and width are unpacked but unused
    height, width, depth = shape_priors[0].shape
    z_i = []
    shape_face_count = torch.zeros([m], dtype=torch.int32)
    shape_faces = []
    for i in range(m):
        # iso-surface of the prior at level 0, vertices divided by the depth
        verts_shape, tri_shape = mcubes.marching_cubes(shape_priors[i], 0.0)
        cur_prior_shape = verts_shape / depth
        
        # set mesh size to 1 and move it to the centering point
        # (size = largest pairwise vertex distance)
        verts_dist = cdist(cur_prior_shape, cur_prior_shape, 'euclidean')
        verts_scaled = cur_prior_shape * 1.0 / verts_dist.max()
        verts_scaled_translation = centering_point - verts_scaled.mean(axis=0)
        verts_translated = verts_scaled + verts_scaled_translation 

        z_i.append(torch.from_numpy(verts_translated))
        shape_faces.append(tri_shape)
        shape_face_count[i] = tri_shape.shape[0]
     
    min_shape_face_count = shape_face_count.min()
    # if k_til is wrongfully implemented or slow, or numerically unstable, 
    # then one can use K_til = K - KE - EK + EKE
    # the middle training shape is used as a stand-in for the mean shape
    mean_shape = z_i[ int(m / 2) ] # try it with Wasserstein barycenter here compute the mean shape
    mean_shape_face = shape_faces[ int(m / 2) ] # save the faces as well
    
    
    # full kernel matrix over all ordered pairs of training shapes
    for i in range(m):
        for j in range(m):
            K[i, j] = kernel(z_i[i], z_i[j], sigma)
    
    # centred kernel matrix
    K_til= K - K @ E - E @ K + E @ K @ E
    
    # keep only real eigenvalues and eigenvectors
    # (eigh returns ascending eigenvalues; both outputs are flipped to descending)
    L, V = torch.linalg.eigh(K_til)
    L = torch.flip(L, [0])
    V = torch.fliplr(V)
    
    limit_val = 1e-6
    if (L <= limit_val).any():
        # index of the first eigenvalue at or below the limit
        first_cplx = torch.where(L <= limit_val)[0][0]
        # NOTE(review): if first_cplx is 0 this indexes L[-1], the smallest eigenvalue
        sigma_ort = L[first_cplx - 1] / 2.0
        
        # truncate the spectrum and the eigenvectors from that index on (in place)
        L[first_cplx:] = 0.0
        V[:, first_cplx:] = 0.0
        reg_mx = torch.eye(K.shape[0])
        
        # retained part plus sigma_ort on the complement of the retained eigenvectors
        Sigma_ort = V @ torch.diag(L) @ V.t() + sigma_ort * (reg_mx - V @ V.t())
    else:  # bad bad things happen
        first_cplx = -1
        sigma_ort = 1
        Sigma_ort = V @ torch.diag(L) @ V.t()
    
    return z_i, torch.linalg.inv(Sigma_ort), L, V, sigma_ort, sigma, first_cplx, min_shape_face_count, mean_shape, mean_shape_face, K.sum(), K

In [None]:
def k_til(k, sigma, x, y, z_i, m):
    """Evaluate the centred kernel between ``x`` and ``y``.

    Computes ``k(x, y)`` minus the averaged responses of ``x`` and of ``y``
    against the first ``m`` samples in ``z_i``, plus the average of ``k`` over
    all ordered pairs of those samples. ``k`` is called as ``k(a, b, sigma)``.
    """
    inv_m = 1 / m
    inv_m_sq = 1 / (m ** 2)

    # averaged responses of both arguments against the samples
    total = 0
    for sample in (z_i[idx] for idx in range(m)):
        total -= inv_m * (k(x, sample, sigma) + k(y, sample, sigma))

    # plain kernel value
    total += k(x, y, sigma)

    # average over all ordered sample pairs
    for first, second in ((z_i[a], z_i[b]) for a in range(m) for b in range(m)):
        total += inv_m_sq * k(first, second, sigma)

    return total

In [None]:
def E_phi(V, kernel, sigma, z_i, z, L, L_ort, r, m):
    """Accumulate the shape energy of ``z`` from centred kernel values.

    For each of the first ``r`` rows of ``V`` and each of the ``m`` samples,
    the squared centred kernel between the sample and ``z`` is weighted by
    ``V[k, i]`` and by ``1 / L[k] - 1 / L_ort``; finally ``L_ort`` times the
    centred kernel of ``z`` with itself is added.
    """
    energy = 0.0
    # takes forever to compute it this way
    for comp in range(r):
        spectral_weight = L[comp] ** (-1) - L_ort ** (-1)
        for idx in range(m):
            centred = k_til(kernel, sigma, z_i[idx], z, z_i, m)
            energy += V[comp, idx] * (centred ** 2) * spectral_weight
    energy += L_ort * k_til(kernel, sigma, z, z, z_i, m)
    return energy

In [None]:
# alpha_i needs some rescaling, named V[k, i] here
def E_phi_grad(V, kernel, k_matrix_sum, sigma, z_i, z, L, L_ort, r, m):
    """Gradient of the kernel-space shape energy with respect to ``z``.

    Args:
        V: (m, >= r) tensor of eigenvectors; it is read only and no longer
            modified by this call.
        kernel: callable ``kernel(a, b, sigma)`` returning a scalar tensor
            that is differentiable with respect to ``z``.
        k_matrix_sum: sum over all entries of the training kernel matrix.
        sigma: kernel parameter forwarded to ``kernel``.
        z_i: sequence of the ``m`` training shapes.
        z: tensor with ``requires_grad=True`` at which the gradient is taken.
        L: eigenvalues; the first ``r`` are used.
        L_ort: eigenvalue assigned to the orthogonal complement.
        r: number of retained components.
        m: number of training shapes.

    Returns:
        Tensor with the shape of ``z``.
    """
    loss = torch.zeros(z.shape)
    # optimized gradient computation
    par_z = torch.zeros([m, *z.shape])   # d/dz of the centred kernel per sample
    k_til = torch.zeros([m])             # centred kernel value per sample
    for i in range(m):
        par_z[i] += torch.autograd.grad(kernel(z_i[i], z, sigma), [z])[0]
        k_til[i] += kernel(z_i[i], z, sigma)
        for k in range(m):
            par_z[i] -= (1/m) * torch.autograd.grad(kernel(z, z_i[k], sigma), [z])[0]
            k_til[i] -= (1/m) * (kernel(z, z_i[k], sigma) + kernel(z_i[i], z_i[k], sigma))
    
    k_til += (1 / (m ** 2)) * k_matrix_sum
    
    # copy.copy() on a tensor shares its storage, so the in-place scaling
    # below used to rescale the caller's V on every call; clone() gives the
    # scaling its own buffer
    alpha = V.clone()
    alpha[:, :r] *= (torch.sqrt(L[:r])[:, None]).t()
    
    for k in range(r):
        for i in range(m):
            loss += (alpha[i, k] * k_til[i])  * (alpha[i, k] * par_z[i]) * (L[k] ** (-1) - L_ort ** (-1))
            
    par_zz = torch.zeros([*z.shape])
    for k in range(m):
        par_zz -= (1/m) * torch.autograd.grad(kernel(z, z_i[k], sigma), [z])[0] # multiplication with 2 is missing    
    loss += (L_ort ** (-1)) * par_zz
    
    return 2.0 * loss

In [None]:
# alpha_i needs some rescaling, named V[k, i] here
def E_phi_grad_opt(V, kernel, k_m, k_matrix_sum, sigma, z_i, z, L, L_ort, r, m):
    """Gradient of the kernel-space shape energy, reusing the kernel matrix.

    Each kernel gradient is evaluated once per training shape and the
    centring terms are assembled from those values and from ``k_m``.

    Args:
        V: (m, >= r) tensor of eigenvectors; it is read only and no longer
            modified by this call.
        kernel: callable ``kernel(a, b, sigma)`` returning a scalar tensor
            that is differentiable with respect to ``z``.
        k_m: (m, m) kernel matrix of the training shapes.
        k_matrix_sum: sum over all entries of ``k_m``.
        sigma: kernel parameter forwarded to ``kernel``.
        z_i: sequence of the ``m`` training shapes.
        z: tensor with ``requires_grad=True`` at which the gradient is taken.
        L: eigenvalues; the first ``r`` are used.
        L_ort: eigenvalue assigned to the orthogonal complement.
        r: number of retained components.
        m: number of training shapes.

    Returns:
        Tensor with the shape of ``z``.
    """
    loss = torch.zeros(z.shape)
    
    # lightspeed optimized gradient computation
    par_z = torch.zeros([m, *z.shape])   # d/dz kernel(z_i[i], z) per sample
    kernel_ = torch.zeros([m])           # kernel(z_i[i], z) per sample
    for i in range(m):
        par_z[i] = torch.autograd.grad(kernel(z_i[i], z, sigma), [z])[0]
        kernel_[i] = kernel(z_i[i], z, sigma)

    # centred kernel values, built from the row sums of the kernel matrix
    k_til = kernel_ - (1/m) * kernel_.sum(dim=0) - (1 /m) * k_m.sum(dim=1) + (1 / (m ** 2)) * k_matrix_sum
    
    # gradient of the centred kernel: per-sample gradient minus the mean gradient
    par_z_sum = (1 / m) * par_z.sum(dim = 0)
    kernel_til = lambda par_z, index : par_z[index] - par_z_sum
    
    # copy.copy() on a tensor shares its storage, so the in-place scaling
    # below used to rescale the caller's V on every call; clone() gives the
    # scaling its own buffer
    alpha = V.clone()
    alpha[:, :r] *= (torch.sqrt(L[:r])[:, None]).t()
    
    for k in range(r):
        for i in range(m):
            loss += (alpha[i, k] * k_til[i])  * (alpha[i, k] * kernel_til(par_z, i)) * (L[k] ** (-1) - L_ort ** (-1))
            
    par_zz = torch.zeros([*z.shape])
    for k in range(m):
        par_zz -= (1/m) * par_z[k]
    loss += (L_ort ** (-1)) * par_zz
    
    return 2.0 * loss

In [None]:
from src.util.timer import tic, toc
import time

def recon_preimg(V, kernel, sigma, x_i, x, r, m):
    """Optimise a randomly initialised tensor ``z`` with L-BFGS.

    The objective is ``-2 * sum_i sum((V ** 2)[i, :] * p * kernel(z, x_i[i]))``
    where ``p[i] = kernel(x, x_i[i])`` holds the kernel responses of ``x``
    against the ``m`` samples. Progress is printed on every closure call. The
    parameter ``r`` is accepted but not used.

    Returns:
        The optimised tensor ``z`` (same shape as ``x``, ``requires_grad`` set).
    """
    # kernel responses of the query against each sample
    query = x.double()
    proj_phi_x = torch.zeros([m])
    for idx in range(m):
        proj_phi_x[idx] = kernel(query, x_i[idx], sigma)

    # initial z for optimization, might need a more clever one...
    z = torch.rand(x.shape)
    z.requires_grad = True

    def loss(candidate):
        total = 0.0
        for idx in range(m):
            response = kernel(candidate, x_i[idx], sigma)
            total += ((V ** 2)[idx, :] * proj_phi_x * response).sum()
        return -2.0 * total

    max_it = 20
    optimizer = torch.optim.LBFGS([z], max_eval=5, max_iter=10, lr=0.5)

    history = []
    print("performing reconstruction optimization...")
    start = time.time()

    def closure():
        optimizer.zero_grad()
        objective = loss(z.double())
        value = objective.detach().cpu().numpy()
        print("loss", value)
        history.append(value)
        objective.backward()
        return objective

    step = 0
    while step < max_it:
        print("it ", step, ": ", end="")
        optimizer.step(closure)
        step += 1

    elapsed = round(time.time() - start, 2)
    print("Optimization (L-BFGS) time: ", elapsed, " seconds")
    return z

# Actual running of the algorithms of shape priors and segmentation with continuous max-flow algorithm

In [None]:
from src.algs.arm import lv_indicator
import matplotlib.pyplot as plt
%matplotlib notebook

# Some thinking is needed in the Mahalanobis distance part
calc_type = torch.float
def segment_left_ventricle(a_volume, a_opt_params, a_algo_params, a_plot=False, a_save_plot=False):
    """
    Main function for segmenting the left ventricle from a reconstructed 3D volume (scalar-field)

    Both parameter packs are unpacked positionally through ``dict.values()``,
    so their insertion order matters. The function also reads the
    module-level names ``calc_type``, ``centering_point`` and ``plt``.

    Args:
        a_volume (N, M, K): torch tensor
                  To be segmented left ventricle SPECT volume (scalar-field)
        
        a_opt_params (dict): dict
                      Parameter pack of the optimization, in this order: upper limit of
                      iterations num_iter, err_bound iteration error between steps limit
                      (currently unused), scaling for the gradient dampening gamma,
                      gradient step size steps
        
        a_algo_params (dict): dict
                       Parameter pack of the Continuous Max-flow algorithm, in this order:
                       par_lambda, par_nu, c_zero, c_one, b_zero, b_one, z_i, sigma_inv, L, V,
                       sigma_ort, sigma, first_cplx, min_shape_face_count, mean_shape,
                       mean_shape_face, k_matrix_sum, k_matrix, kernel
        
        a_plot (bool): bool
                Bit flag to use plotting of intermediate results or not

        a_save_plot (bool): bool
                     Bit flag to save plot of intermediate results or not

    Returns:
        (u, err_iter, num_iter, f_one): the labelling function, the per-step
        error history, the number of outer iterations and the voxelised shape
        prior of the last outer iteration.
    """
    num_iter, err_bound, gamma, steps = a_opt_params.values()
    par_lambda, par_nu, c_zero, c_one, b_zero, b_one, z_i, sigma_inv, L, V, sigma_ort, sigma, first_cplx, min_shape_face_count, mean_shape, mean_shape_face, k_matrix_sum, k_matrix, kernel = a_algo_params.values()
    m = len(z_i)
    
    norm_epsilon = 0.001

    # mentioned in the paper
    b_zero = c_zero
    b_one = c_one

    volume = a_volume
    density_vol = volume / volume.sum()

    rows, cols, height = a_volume.shape
    im_size = rows * cols * height
    
    # initialization for CMF
    if a_volume.dtype == torch.int32:
        # torch tensors have no astype(); convert with to()
        a_volume = a_volume.to(calc_type)

    alpha = 2 / (par_lambda + par_nu)

    lv_params = dict(a=1, c=2, sigma=-1)
    f_zero = torch.from_numpy(lv_indicator(a_volume, lv_params))
    f_one = f_zero

    im_eff = (par_lambda / (par_lambda + par_nu)) * a_volume + (par_nu / (par_lambda + par_nu)) \
             * (b_zero * (1 - f_one) + b_one * f_one)

    Cs = (im_eff - c_zero) ** 2
    Ct = (im_eff - c_one) ** 2
    
    u_prev = torch.zeros([rows, cols, height])
    u = torch.where(Cs >= Ct, 1, 0).float()

    ps = torch.minimum(Cs, Ct)
    pt = ps

    # flow components live on the staggered grid, one extra sample per axis
    pp_y = torch.zeros((rows, cols + 1, height), dtype=calc_type)
    pp_x = torch.zeros((rows + 1, cols, height), dtype=calc_type)
    pp_z = torch.zeros((rows, cols, height + 1), dtype=calc_type)
    div_p = torch.zeros((rows, cols, height), dtype=calc_type)

    cmf_iter = 3
    err_iter = torch.zeros(cmf_iter * num_iter, dtype=calc_type)
    norm_u_iter = torch.zeros(num_iter + 1, dtype=calc_type)
    
    if a_plot is True:
        plt.ion()

        figure, axis = plt.subplots(2, 2)
        figure.tight_layout()
        
        # torch.int32 is a dtype object and cannot be called; use a plain int
        slice_num = a_volume.shape[0] // 2

        plot_obj_vol = axis[0, 0].imshow(a_volume[slice_num, :, :])
        axis[0, 0].set_title("Left Ventricle Volume")

        plot_obj_seg = axis[0, 1].imshow(f_one[slice_num, :, :])
        axis[0, 1].set_title("Segmentation")

        plot_obj_opt = axis[1, 1].imshow(u[slice_num, :, :])
        axis[1, 1].set_title("Optimality")

        plot_obj_err = axis[1, 0].plot(err_iter[0])
        axis[1, 0].set_title("Iteration error")

        plt.show()

    for i in range(num_iter):
        for j in range(cmf_iter):
            pts = div_p - (ps - pt + u / gamma)

            pp_y[:, 1:-1, :] += steps * (pts[:, 1:, :] - pts[:, :-1, :])
            pp_x[1:-1, :, :] += steps * (pts[1:, :, :] - pts[:-1, :, :])
            pp_z[:, :, 1:-1] += steps * (pts[:, :, 1:] - pts[:, :, :-1])

            # the following steps give the projection to make |p(x)| <= alpha(x)
            squares = pp_y[:, :-1, :] ** 2 + pp_y[:, 1:, :] ** 2
            squares += pp_x[:-1, :, :] ** 2 + pp_x[1:, :, :] ** 2
            squares += pp_z[:, :, :-1] ** 2 + pp_z[:, :, 1:] ** 2

            gk = torch.sqrt(squares * .5)
            gk = (gk <= alpha) + torch.logical_not(gk <= alpha) * (gk / alpha)
            gk = 1 / gk

            pp_y[:, 1:-1, :] = (.5 * (gk[:, 1:, :] + gk[:, :-1, :])) * (pp_y[:, 1:-1, :])
            pp_x[1:-1, :, :] = (.5 * (gk[1:, :, :] + gk[:-1, :, :])) * (pp_x[1:-1, :, :])
            pp_z[:, :, 1:-1] = (.5 * (gk[:, :, 1:] + gk[:, :, :-1])) * (pp_z[:, :, 1:-1])

            div_p = pp_y[:, 1:, :] - pp_y[:, :-1, :]
            div_p += pp_x[1:, :, :] - pp_x[:-1, :, :]
            div_p += pp_z[:, :, 1:] - pp_z[:, :, :-1]

            # update the source flow ps
            pts = div_p + pt - u / gamma + 1 / gamma
            ps = torch.minimum(pts, Cs)

            # update the sink flow pt
            pts = -div_p + ps + u / gamma
            pt = torch.minimum(pts, Ct)

            u_error = gamma * (div_p - ps + pt)
            u -= u_error

            u_error_normed = torch.sum(torch.abs(u_error)) / im_size
            err_iter[cmf_iter * i + j] = u_error_normed

            if a_plot is True:
                plot_obj_opt.set_data(u[slice_num, :, :])
                axis[1, 0].plot(err_iter[0: cmf_iter * i + j])
                plt.draw()
        
        norm_u_iter[i + 1] = torch.linalg.norm(u)
        
        # region means of the effective image inside / outside the labelling
        c_zero = torch.sum((1 - u) * im_eff) / torch.sum(1 - u)
        c_one = torch.sum(u * im_eff) / (torch.sum(u))

        im_mod = c_zero * (1 - u) + c_one * u

        b_zero = torch.sum((1 - f_one) * im_mod) / (torch.linalg.norm(1 - f_one + norm_epsilon) ** 2)
        b_one = torch.sum(f_one * im_mod) / (torch.linalg.norm(f_one + norm_epsilon) ** 2)
                
        print("u sum: ", u.sum(), "u max: ", u.max(), "u min: ", u.min(), "u count:", (u > 0).sum())
        
        zero_volume_boundary(u, a_width=2)
        #vert_vol, tri_vol = mcubes.marching_cubes(u.numpy(), 0.1)
        vert_vol, tri_vol, _, _ =  measure.marching_cubes(u.numpy(), 0.2)  # 0.5 for cardioc, 0.2 for mph, 0.1 for parallel, 0.1 for cardiod
        cv, nv, cf, nf = pcu.connected_components(vert_vol, tri_vol.astype(np.int32))
        
        num_components = nv.size
        norm_diff = torch.abs(norm_u_iter[i + 1] - norm_u_iter[i])
        print("Connected components: ", num_components)
        print("Iteration: ", i, "Norm diff: ", norm_diff)
        
        f_one = torch.zeros((rows, cols, height))
        
        # The shape prior is only applied once the norm of u has settled
        # (threshold 0.178; 3 for parallel images) or on the last outer
        # iteration. This is the original condition
        # `(A and B) or C` followed by an immediate break when A is false.
        apply_prior = (norm_diff < 0.178) or (i + 1) == num_iter
        nu = 1e-2  # step size of the shape-gradient update
        component = 0

        while apply_prior and component < num_components:
            if num_components > 1:
                component_face_count = nf[component]
            else:
                component_face_count = nf
            
            v_decimate, f_decimate, v_correspondence, f_correspondence =\
                pcu.decimate_triangle_mesh(vert_vol, tri_vol[cf == component].astype(np.int32), min(min_shape_face_count.numpy(), component_face_count))
                        
            z = torch.from_numpy(v_decimate / cols)
            z.requires_grad = True
            
            # renorm to size 1 and translate it to center_point
            z_dist = cdist(z.detach().numpy(), z.detach().numpy(), 'euclidean')
            max_real_size = z_dist.max() * cols
            if max_real_size <= 20:  # dummy "size" selection 15 for parallel geometries, 20 for mph -> sharpen this
                print("Skipping object with diameter: ", max_real_size)
                component = component  + 1
                continue
            
            z_scaled = z * (1.0 / (z_dist.max()))
            z_translation = z_scaled.mean(dim=0) - torch.from_numpy(centering_point)
            
            # Project current shape on the mean shape as in [1]
            pc_fix = PointCloud(mean_shape.detach().numpy(), columns=["x", "y", "z"])
            pc_mov = PointCloud((z_scaled - z_translation).detach().numpy(), columns=["x", "y", "z"])
            icp = SimpleICP()
            icp.add_point_clouds(pc_fix, pc_mov)
            H, proj_mean_icp, rigid_body_transformation_params, distance_residuals = icp.run(max_overlap_distance=1)
            # add reorientation based registration here
            
            proj_mean = torch.from_numpy(proj_mean_icp)
            proj_mean.requires_grad = True
            
            grad_E = E_phi_grad_opt(V, kernel, k_matrix, k_matrix_sum, sigma, z_i, proj_mean, L, sigma_ort, first_cplx, m)
            
            # the last terms in the gradient calculation
            # d til_z / d_z_c * d_z_c / d_z
            Rot = H[:-1, :-1]
            translation = H[:-1, -1]
            
            # gradient step in the registered frame, then undo registration,
            # centring and scaling
            it_shape = ((proj_mean - nu * grad_E - torch.from_numpy(translation)) @ torch.from_numpy(Rot) + z_translation) * z_dist.max()
            print("Translation:", translation, "Normalization factor: ", z_dist.max(), "Mean translation: ", z_translation)
                        
            # The in-loop `import matplotlib.pyplot as plt` / `%matplotlib`
            # pair was dropped: the import made `plt` a local name and broke
            # plt.ion() above whenever a_plot was True; the module-level
            # import in this cell already provides plt.
            fig = plt.figure()
            ax1 = fig.add_subplot(1, 2, 1, projection='3d')
            plot_1 = z.detach().numpy()
            ax1.scatter(plot_1[:, 0], plot_1[:, 1], plot_1[:, 2])
            ax2 = fig.add_subplot(1, 2, 2, projection='3d')
            plot_2 = it_shape.detach().numpy()
            ax2.shareview(ax1)
            ax2.scatter(plot_2[:, 0], plot_2[:, 1], plot_2[:, 2])
            # 
            # component = component + 1
            # continue
            
            # voxelization and bounds checking
            ijk = pcu.voxelize_triangle_mesh((it_shape).detach().numpy() * cols, f_decimate.astype(np.int32), 1, [0., 0., 0.])
            ijk = ijk[np.sum(np.logical_and(ijk >=0, ijk < cols), axis=1) == 3, :]
            ijk = ijk[ijk[:, 0] < rows] # more likely that the axial dim is different
            
            f_one[ijk[:, 0], ijk[:, 1], ijk[:, 2]] = 1

            print("component: ", component, " energy min: ", grad_E.min(), " energy max: ", grad_E.max(),
                  " shape prior count: ", f_one.sum(), " shape prior mean pos:", (it_shape).mean())
            
            component = component + 1
    
        if a_plot is True:
            plot_obj_seg.set_data(f_one[slice_num, :, :])
            plt.draw()

            if a_save_plot is True:
                name = "..\\..\\left_ventricle_" + str(i) + ".png"
                plt.savefig(name, bbox_inches='tight', pad_inches=0)

        im_eff = (par_lambda / (par_lambda + par_nu)) * a_volume + (par_nu / (par_lambda + par_nu)) \
                 * (b_zero * (1 - f_one) + b_one * f_one)
        
        # binary support of u, used to weight the data terms
        H = torch.where(u > 0, 1, 0)
        Cs = (im_eff - c_zero) ** 2 * torch.kl_div(H * density_vol, (1 - H) * density_vol).sum()
        Ct = (im_eff - c_one) ** 2 * torch.kl_div((1 - H) * density_vol, H * density_vol).sum()
    
    return u, err_iter, num_iter, f_one

In [None]:
# might need to try with different models from CV, e.g.: MS, Potts
from geomloss import SamplesLoss
eps = 5 * 1e-3
loss_unbalanced = SamplesLoss(loss='sinkhorn', p=2, blur=eps, scaling=0.95)
sigma = 5 * 1e0
# k = lambda x, y, sigma : torch.exp(-loss(x, y) ** 2 / (2 * sigma ** 2))
k = lambda x, y, sigma : torch.exp(-sigma * loss_unbalanced(x, y))
centering_point = np.array([0.45, 0.45, 0.45])

z_i, sigma_inv, L, V, sigma_ort, sigma, first_cplx, min_shape_face_count, mean_shape, mean_shape_face, k_matrix_sum, k_matrix  = nonlinear_shape_prior(shape_priors, kernel=k, sigma=sigma, centering_point=centering_point)

In [None]:
print(V.max())

In [None]:
import pyprof
import torch.cuda.profiler as profiler 

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    with_flops=True) as prof:
        opt_params = dict(num_iter=10, err_bound=0, gamma=1e-2, steps=1e-1)
        cmf_params = dict(par_lambda=2, par_nu=3, c_zero=0.3, c_one=0.7, b_zero=1e-1, b_one=1e1,
                          z_i=z_i, sigma_inv=sigma_inv, L=L, V=V, sigma_ort=sigma_ort, sigma=sigma, first_cplx=first_cplx, min_shape_face_count=min_shape_face_count, mean_shape=mean_shape, mean_shape_face=mean_shape_face,k_matrix_sum=k_matrix_sum, k_matrix=k_matrix, kernel=k)
        lam, err_iter, num_iter, lam_shape_prior = segment_left_ventricle(a_volume=torch.from_numpy(lv_volume), a_opt_params=opt_params, a_algo_params=cmf_params)

events = prof.events()
cmf_shape_flops = sum([int(evt.flops) for evt in events]) 
print("Runtime FLOPs: ", cmf_shape_flops)

In [None]:
print(prof)
events = prof.events()
cmf_shape_flops = sum([int(evt.flops) for evt in events]) 
print("Runtime FLOPs: ", cmf_shape_flops)

In [None]:
import matplotlib.pyplot as plt

%matplotlib notebook
slice = 30

fig = plt.figure()
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(lam[slice, :, :])
ax2 = fig.add_subplot(1, 3, 2)
ax2.imshow(lv_volume[slice, :, :])
ax3 = fig.add_subplot(1, 3, 3)
ax3.imshow(lam_shape_prior[slice, :, :])
plt.show()

In [None]:
# Inspection figures for the paper: run a 3-D flood fill on a copy of the shape-prior labels.
fill_value = 3  # value written by the flood fill, which is seeded at voxel [0, 0, 0]
label_prior = cp.copy(lam_shape_prior)  # copy made with copy.copy before calling the fill
filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)

import matplotlib.pyplot as plt

%matplotlib notebook
slice = 30

fig = plt.figure()
ax1 = fig.add_subplot(1, 2, 1)
ax1.imshow(filled_myocard[slice, :, :])
ax1 = fig.add_subplot(1, 2, 2)
ax1.imshow(lv_volume[slice, :, :])

In [None]:
# fig = plt.figure()
# fig.set_size_inches((1, 1))
# ax = plt.Axes(fig, [0., 0., 1., 1.])
# ax.set_axis_off()
# fig.add_axes(ax)
# # turkey pat 10, (TRA 30, 17:37, 30:50), (VLA 23:43, 28, 30:50), (SA 23:43, 17:37, 42)
# 
# plt.imshow(lv_volume[23:43, 17:37, 42], aspect='equal')
# plt.savefig("/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/images/" + "patient_10_sa" + ".png", bbox_inches='tight', pad_inches=0)
# plt.close()

## Reconstructing one of the eigenvalues/eigenvectors to check the correctness of embedding

In [None]:
from geomloss import SamplesLoss
eps = 5 * 1e-1
loss = SamplesLoss(loss='sinkhorn', p=2, blur=eps)
sigma = 5 * 1e-5
k = lambda x, y, sigma : torch.exp(-loss(x, y) ** 2 / (2 * sigma ** 2))

z_i, sigma_inv, L, V, sigma_ort, _, first_cplx, min_shape_face_count, mean_shape, mean_shape_face, k_matrix_sum, k_matrix = nonlinear_shape_prior(shape_priors, kernel=k, sigma=sigma, centering_point=centering_point)

In [None]:
# input_shape = torch.rand([400, 3])
verts, faces = mcubes.marching_cubes(lv_volume, 0.5)
v_decimate, f_decimate, v_correspondence, f_correspondence = pcu.decimate_triangle_mesh(verts, faces.astype(np.int32), min_shape_face_count)
input_shape = torch.from_numpy(v_decimate / height)
print(loss(input_shape.double(), z_i[0]))
print(loss(z_i[1], z_i[-1]))

In [None]:
sigma = 5 * 1e-1
pre_z = recon_preimg(V, k, sigma, z_i, input_shape, first_cplx, len(z_i)).detach().numpy()

In [None]:
# visualization
import matplotlib.pyplot as plt
%matplotlib notebook

model_ind = 0
model = z_i[model_ind].detach().numpy()

input = input_shape.detach().numpy()

fig = plt.figure()

ax1 = fig.add_subplot(1, 3, 1, projection='3d')
ax1.scatter(pre_z[:, 0], pre_z[:, 1], pre_z[:, 2])

ax2 = fig.add_subplot(1, 3, 2, projection='3d')
ax2.shareview(ax1)
ax2.scatter(input[:, 0], input[:, 1], input[:, 2])

ax3 = fig.add_subplot(1, 3, 3, projection='3d')
ax3.shareview(ax1)
ax3.scatter(model[:, 0], model[:, 1], model[:, 2])

print("Reconstructed surface shape: ", pre_z.shape, " Input shape: ", input_shape.shape)
plt.show()

In [None]:
# visualization
import matplotlib.pyplot as plt
%matplotlib notebook

recon_model_plot = pre_z * width

fig = plt.figure()

ax1 = fig.add_subplot(1, 1, 1, projection='3d')
ax1.scatter(recon_model_plot[:, 0], recon_model_plot[:, 1], recon_model_plot[:, 2])
plt.show()

## Check optimal transport projection

In [None]:
lv_params = dict(a=1, c=2, sigma=-1)
vol = lv_indicator(np.zeros([64, 64, 64]), lv_params)
input_shape_verts, input_shape_faces = mcubes.marching_cubes(vol, 0.0)

In [None]:
x_shape = torch.from_numpy(input) / 64 # torch.from_numpy(input_shape_verts / 64)
y_shape = mean_shape

print("Source shape mean:", x_shape.mean(), "Target shape mean:", y_shape.mean())
translation = y_shape.mean(dim=0) - x_shape.mean(dim=0)
print("Translation:", translation)
print("Source shape mean with translation:", (x_shape + translation).mean())

In [None]:
print(torch.linalg.norm(x_shape))
print(torch.linalg.norm(y_shape))   

In [None]:
x = (x_shape).double().t().flatten()[:, None] / torch.linalg.norm(x_shape)
y = y_shape.double().flatten()[:, None] / torch.linalg.norm(y_shape)

a = torch.ones(x.shape)
b = torch.ones(y.shape)

print("Source mean: ", x.mean(), "Target mean: ", y.mean())
print(x.shape, y.shape)

N, M, D = x.shape[0], y.shape[0], x.shape[1]
p = 2
blur = 5 * 1e-2

OT_solver = SamplesLoss(loss = "sinkhorn", p = p, blur = blur, scaling=0.2, debias=False, potentials=True)
F, G = OT_solver(x, y)  # Dual potentials

x_i, y_j = x.view(N, 1, D), y.view(1, M, D)
a_i, b_j = a.view(N, 1), b.view(1, M)
F_i, G_j = F.view(N, 1), G.view(1, M)

C_ij = (1 / p) * ((x_i - y_j) ** p).sum(-1)
eps = blur ** p
P_ij = ((F_i + G_j - C_ij) / eps).exp()

proj_mean = x.flatten() @ (P_ij)
proj_mean = torch.reshape(proj_mean, mean_shape.shape)

In [None]:
torch.norm(mean_shape - proj_mean)

In [None]:
# try estimating the rotation mx with wasserstein metric
x = x_shape.double()
y = y_shape.double()

In [None]:
from scipy.spatial.distance import cdist

x_dist = cdist(x.numpy(), x.numpy(), 'euclidean')
y_dist = cdist(y.numpy(), y.numpy(), 'euclidean')

print(x.max(), y.max())
print(x.min(), y.min())
print(x.mean(), y.mean())
print(x_dist.max(), y_dist.max())

In [None]:
x_scale = 1.0 / x_dist.max()
y_scale = 1.0 / y_dist.max()

print(x_scale)
print(y_scale)

In [None]:
x_scaled = x * x_scale
y_scaled = y * y_scale

x_scaled_mean = x_scaled.mean(dim=0)
y_scaled_mean = y_scaled.mean(dim=0)

x_translation = torch.tensor([0.45, 0.45, 0.45]) - x_scaled_mean
y_translation = torch.tensor([0.45, 0.45, 0.45]) - y_scaled_mean

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

fig = plt.figure()

ax1 = fig.add_subplot(1, 2, 1, projection='3d')
plot_input_shape_verts = x_scaled + x_translation
ax1.scatter(plot_input_shape_verts[:, 0], plot_input_shape_verts[:, 1], plot_input_shape_verts[:, 2])

ax2 = fig.add_subplot(1, 2, 2, projection='3d')
ax2.shareview(ax1)
output_shape_verts = (y_scaled + y_translation).detach().numpy()
ax2.scatter(output_shape_verts[:, 0], output_shape_verts[:, 1], output_shape_verts[:, 2])

plt.show()

In [None]:
# try out ICP
from simpleicp import PointCloud, SimpleICP

pc_fix = PointCloud((y_scaled + y_translation).detach().numpy(), columns=["x", "y", "z"])
pc_mov = PointCloud(x_scaled + x_translation, columns=["x", "y", "z"])

# Create simpleICP object, add point clouds, and run algorithm!
icp = SimpleICP()
icp.add_point_clouds(pc_fix, pc_mov)
H, proj_mean_icp, rigid_body_transformation_params, distance_residuals = icp.run(max_overlap_distance=1)

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

fig = plt.figure()

ax1 = fig.add_subplot(1, 3, 1, projection='3d')
orig_input_shape = x.detach().numpy()
ax1.scatter(orig_input_shape[:, 0], orig_input_shape[:, 1], orig_input_shape[:, 2])

ax2 = fig.add_subplot(1, 3, 2, projection='3d')
ax2.shareview(ax1)
plot_input_shape_verts = proj_mean_icp
ax2.scatter(plot_input_shape_verts[:, 0], plot_input_shape_verts[:, 1], plot_input_shape_verts[:, 2])

ax3 = fig.add_subplot(1, 3, 3, projection='3d')
ax3.shareview(ax1)
output_shape_verts = (y_scaled + y_translation).detach().numpy()
ax3.scatter(output_shape_verts[:, 0], output_shape_verts[:, 1], output_shape_verts[:, 2])
plt.show()

In [None]:
print(H)

In [None]:
# test inverse rotation
Rot = H[:-1, :-1]
print(Rot)
inv_Rot = np.linalg.inv(Rot)
print(inv_Rot)
translation = H[:-1, -1]
print(translation)

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

# Side-by-side 3-D scatter plots sharing one camera view.
# NOTE(review): both panels scatter the same point set `x`; `inv_Rot` and `translation`
# computed in the preceding cell are not applied here — confirm whether the left panel
# was meant to show the inverse-transformed shape.
fig = plt.figure()
ax1 = fig.add_subplot(1, 2, 1, projection='3d')
inv_rot_input_shape = x.detach().numpy()
ax1.scatter(inv_rot_input_shape[:, 0], inv_rot_input_shape[:, 1], inv_rot_input_shape[:, 2])

ax3 = fig.add_subplot(1, 2, 2, projection='3d')
ax3.shareview(ax1)
output_shape_verts = x.detach().numpy()
ax3.scatter(output_shape_verts[:, 0], output_shape_verts[:, 1], output_shape_verts[:, 2])

plt.show()

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

fig = plt.figure()

ax1 = fig.add_subplot(1, 3, 1, projection='3d')
plot_input_shape_verts = (x_shape + translation) * 64
ax1.scatter(plot_input_shape_verts[:, 0], plot_input_shape_verts[:, 1], plot_input_shape_verts[:, 2])

ax2 = fig.add_subplot(1, 3, 2, projection='3d')
ax2.shareview(ax1)
output_shape_verts = proj_mean.detach().numpy()
ax2.scatter(output_shape_verts[:, 0], output_shape_verts[:, 1], output_shape_verts[:, 2])

ax3 = fig.add_subplot(1, 3, 3, projection='3d')
ax3.shareview(ax1)
mean_shape_verts = mean_shape * 64
ax3.scatter(mean_shape_verts[:, 0], mean_shape_verts[:, 1], mean_shape_verts[:, 2])

plt.show()

### Check theoretical properties of the kernel built on the chosen OT loss

In [None]:
def kernel_matrix(z_i, kernel, sigma):
    """Build the centred kernel matrix of the samples and its regularised spectrum.

    Args:
        z_i: sequence of samples; every ordered pair is passed to ``kernel``.
        kernel: callable ``kernel(x, y, sigma)`` returning a scalar value.
        sigma: scale parameter forwarded unchanged to ``kernel``.

    Returns:
        Tuple ``(K_til, K, Sigma_ort, V, L, K.sum(), first_cplx, sigma_ort)``:

        * ``K_til``: centred kernel matrix ``(I - E) K (I - E)`` with ``E = 1/m``.
        * ``K``: raw ``m x m`` kernel matrix.
        * ``Sigma_ort``: ``V diag(L) V^T + sigma_ort * (I - V V^T)``.
        * ``V``: eigenvectors of ``K_til`` as columns, ordered by descending
          eigenvalue; columns from ``first_cplx`` on are zeroed.
        * ``L``: eigenvalues in descending order; entries from ``first_cplx``
          on are zeroed.
        * ``K.sum()``: sum of all raw kernel entries.
        * ``first_cplx``: 0-d integer tensor, index of the first eigenvalue that
          is ``<= 1e-6``, or ``m`` when every eigenvalue is above that limit.
        * ``sigma_ort``: half of the last retained eigenvalue.
    """
    m = len(z_i)

    K = torch.zeros([m, m])
    E = (1 / m) * torch.ones([m, m])

    # The kernel is evaluated for every ordered pair; no symmetry is assumed.
    for i in range(m):
        for j in range(m):
            K[i, j] = kernel(z_i[i], z_i[j], sigma)
    centering = torch.eye(m) - E
    K_til = centering @ K @ centering

    # eigh returns ascending eigenvalues; flip so the largest comes first.
    L, V = torch.linalg.eigh(K_til)
    L = torch.flip(L, [0])
    V = torch.fliplr(V)

    limit_val = 1e-6
    # Bound unconditionally: previously these names only existed when some
    # eigenvalue fell below the limit, so the code after the branch failed
    # with an unbound-variable error otherwise.
    reg_mx = torch.eye(K.shape[0])
    negligible = L <= limit_val
    if negligible.any():
        first_cplx = torch.where(negligible)[0][0]
        # NOTE(review): when first_cplx is 0 this index wraps around to the
        # last (smallest) eigenvalue — confirm the intended value for that
        # degenerate case.
        sigma_ort = L[first_cplx - 1] / 2.0

        L[first_cplx:] = 0.0
        V[:, first_cplx:] = 0.0
    else:
        # Every eigenvalue is significant: nothing is truncated and the last
        # retained eigenvalue is simply the smallest one.
        first_cplx = torch.tensor(m)
        sigma_ort = L[-1] / 2.0

    Sigma_ort = V @ torch.diag(L) @ V.t() + sigma_ort * (reg_mx - V @ V.t())

    return K_til, K, Sigma_ort, V, L, K.sum(), first_cplx, sigma_ort

In [None]:
from geomloss import SamplesLoss

eps = 5 * 1e-3

# balanced and unbalanced loss fns
loss_balanced = SamplesLoss(loss='sinkhorn', p=2, blur=eps)
loss_unbalanced = SamplesLoss(loss='sinkhorn', p=2, blur=eps, reach=0.7, scaling=0.95)

sigma = 5 * 1e0
k_balanced = lambda x, y, sigma : torch.exp(-sigma * loss_balanced(x, y))
k_unbalanced = lambda x, y, sigma : torch.exp(-sigma * loss_unbalanced(x, y))

K_til_balanced, K_balanced, Sigma_balanced, V_balanced, L_balanced, K_sum_balanced, first_cplx_balanced, sigma_ort_balanced = kernel_matrix(z_i, k_balanced, sigma)
K_til_unbalanced, K_unbalanced, Sigma_unbalanced, V_unbalanced, L_unbalanced, K_sum_unbalanced, first_cplx_unbalanced, sigma_ort_unbalanced = kernel_matrix(z_i, k_unbalanced, sigma)

In [None]:
L_bal = torch.linalg.cholesky(K_til_balanced)

In [None]:
L_ubal = torch.linalg.cholesky(K_til_unbalanced)

In [None]:
import matplotlib.pyplot as plt

%matplotlib notebook

fig = plt.figure()
ax1 = fig.add_subplot(2, 3, 1)
ax1.imshow(K_til_balanced)
ax2 = fig.add_subplot(2, 3, 2)
ax2.imshow(K_balanced)
ax3 = fig.add_subplot(2, 3, 3)
ax3.imshow(Sigma_balanced)

ax4 = fig.add_subplot(2, 3, 4)
ax4.imshow(K_til_unbalanced)
ax5 = fig.add_subplot(2, 3, 5)
ax5.imshow(K_unbalanced)
ax6 = fig.add_subplot(2, 3, 6)
ax6.imshow(Sigma_unbalanced)

plt.show()

In [None]:
# Eigen-decomposition of the centred balanced kernel matrix. eigh returns the
# eigenvalues in ascending order, so flip them to get the largest first.
L_bal_, V_bal_ = torch.linalg.eigh(K_til_balanced)
L_bal = torch.flip(L_bal_, [0])
# NOTE(review): this reverses rows as well as columns, whereas kernel_matrix()
# only reverses columns (fliplr) — confirm which one is intended.
V_bal = torch.flip(V_bal_, [0, 1])
# Print the spectrum just computed (previously the unrelated global `L` was printed).
print(L_bal)

# Same check for the unbalanced kernel matrix.
L_ubal_, V_ubal_ = torch.linalg.eigh(K_til_unbalanced)
L_ubal = torch.flip(L_ubal_, [0])
V_ubal = torch.flip(V_ubal_, [0, 1])
print(L_ubal)

In [None]:
for i in range(len(z_i)):
    print(V_bal[:, i] @ V_bal[:, i])

In [None]:
print(torch.linalg.det(K_til_balanced))
print(torch.linalg.det(K_balanced))
print(torch.linalg.det(Sigma_balanced))
print(torch.linalg.det(K_til_unbalanced))
print(torch.linalg.det(K_unbalanced))
print(torch.linalg.det(Sigma_unbalanced))

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

fig = plt.figure()
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(V_bal_)
ax2 = fig.add_subplot(1, 3, 2)
ax2.imshow(V_bal)
ax3 = fig.add_subplot(1, 3, 3)
ax3.imshow(torch.fliplr(V_bal_))

### Check gradient, mean projection and test point projection as well

In [None]:
import time

z = torch.from_numpy(input) # torch.from_numpy(mean_shape / 64).double()
z.requires_grad = True

start = time.time()
grad_E = E_phi_grad(V_unbalanced, k_unbalanced, K_sum_unbalanced, 5 * 1e0, z_i, z, L_unbalanced, sigma_ort_unbalanced, first_cplx_unbalanced, len(z_i))
print("Naive gradient calculation time: ", round(time.time() - start, 2), " seconds")

start = time.time()
grad_E_opt = E_phi_grad_opt(V_unbalanced, k_unbalanced, K_unbalanced, K_sum_unbalanced, 5 * 1e0, z_i, z, L_unbalanced, sigma_ort_unbalanced, first_cplx_unbalanced, len(z_i))
print("Optimized gradient calculation time: ", round(time.time() - start, 2), " seconds")

In [None]:
print("Lazy gradient max: ", grad_E.max(), "Lazy gradient min: ", grad_E.min())
print("Opt gradient max: ", grad_E_opt.max(), "Opt gradient min: ", grad_E_opt.min())
print("Lazy grad norm: ", torch.linalg.norm(grad_E))
print("Opt grad norm: ", torch.linalg.norm(grad_E_opt))
print("Norm difference: ", torch.linalg.norm(grad_E - grad_E_opt))

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

nu = 1e0

it_shape = z * 64 - nu * grad_E # z * 64 - grad_E #torch.from_numpy(input).double() # proj_mean * 64 - grad_E

fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1, projection='3d')
plot = (it_shape).detach().numpy()
ax1.scatter(plot[:, 0], plot[:, 1], plot[:, 2])

In [None]:
source_shape = z.detach().numpy()
print(source_shape.max(), source_shape.min(), source_shape.mean())

In [None]:
result_shape = it_shape.detach().numpy()
print(result_shape.max(), result_shape.min(), result_shape.mean())

In [None]:
(input.mean())

In [None]:
(input_shape_verts / 64).mean()

In [None]:
(mean_shape).mean()

In [None]:
(z).mean()

## Now scalarize the model and check the variability of the statistical model

In [None]:
import pymeshlab
pc = pymeshlab.MeshSet()
pc.add_mesh(pymeshlab.Mesh(recon_model_plot))
pc.compute_normal_for_point_clouds()
pc.generate_surface_reconstruction_ball_pivoting()
mesh = pc.current_mesh()

faces = mesh.face_matrix()
f_one = torch.zeros(lv_volume.shape)
ijk = pcu.voxelize_triangle_mesh(recon_model_plot, faces.astype(np.int32), 1, [0., 0., 0.])

f_one[ijk[:, 0], ijk[:, 1], ijk[:, 2]] = 1
print(f_one.sum())

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_trisurf(recon_model_plot[:, 0], recon_model_plot[:, 1], recon_model_plot[:, 2], triangles = faces.astype(np.int32), edgecolor=[[0,0,0]], linewidth=1.0, alpha=0.0, shade=False)

In [None]:
pts = np.random.rand(9000, 3) * (recon_model_plot.max(0) - recon_model_plot.min(0)) + recon_model_plot.min(0)

sdfs, face_ids, barycentric_coords = pcu.signed_distance_to_mesh(pts.astype(np.float32), recon_model_plot, faces.astype(np.int32))
sdf_vol = np.zeros([64, 64, 64])
sdf_vol[pts[:, 0].astype(np.int32), pts[:, 1].astype(np.int32), pts[:, 2].astype(np.int32)] = sdfs

In [None]:
print(sdfs.max())

In [None]:
import matplotlib.pyplot as plt
slice = 55

fig = plt.figure()
plt.imshow(f_one[slice, :, :])
plt.show()

## Reconstruct model parameters from measurement

In [None]:
# takes too long, running an EM algorithm here
recon_coords = torch.from_numpy(pre_z)
volume = np.zeros([*lv_volume.shape])

sigma = 5 * 1e0

# parameter limits
# wall_thickness = np.random.uniform(0.3, 1.0, num_prior)
# rot_angles = np.random.uniform(0, 2 * np.pi, num_prior)
# curvature = np.random.uniform(1.5, 3, num_prior)
# sigmas = np.random.uniform(-0.5, -1, num_prior)

points = torch.zeros([len(shape_priors), 3])
N = len(shape_priors)
data_size = int(N * N)
euclidean_dist = torch.zeros(data_size)
feature_dist = torch.zeros(data_size)

for i in range(len(shape_priors)):
    points[i] = torch.tensor([wall_thickness[i], curvature[i], sigmas[i]])
    
for i in range(len(shape_priors)):
    for j in range(len(shape_priors)):
        euclidean_dist[i * (len(shape_priors)) + j] = torch.cdist(points[i, None], points[j, None], p=2)
        feature_dist[i * (len(shape_priors)) + j] = k(z_i[i], z_i[j], sigma)
            
from scipy import interpolate

f = interpolate.interp1d(feature_dist, euclidean_dist, fill_value='extrapolate')

## Now estimate coordinates based on feature space distance

In [None]:
feature_dist_recon = torch.zeros(N)
euclidean_dist_recon = torch.zeros(N)
for i in range(N):
    feature_dist_recon[i] = k(recon_coords.double(), z_i[i].double(), sigma)
    euclidean_dist_recon[i] = torch.from_numpy(f(feature_dist_recon[i]))

euclidean_dist_all = torch.cat([euclidean_dist, euclidean_dist_recon])
feature_dist_all = torch.cat([feature_dist, feature_dist_recon])
sorted, indices = torch.sort(feature_dist_all, 0)

print(euclidean_dist_all[indices].shape)
print(feature_dist.shape)
# print(euclidean_dist_all)

import matplotlib.pyplot as plt
%matplotlib notebook
fig = plt.figure()
plt.plot(feature_dist, euclidean_dist, 'o', sorted, euclidean_dist_all[indices], '-')
plt.show()

In [None]:
print(feature_dist_recon)

In [None]:
import math
import scipy.optimize

def func(par):
    """Return the squared distance from ``par`` to each of the first ``N`` rows of ``points``.

    ``par`` unpacks into three coordinates; ``N`` and ``points`` are read from
    the enclosing (notebook-global) scope. The result is a tensor of length ``N``.
    """
    px, py, pz = par
    sq_dist = torch.zeros(N)
    for row in range(N):
        ref = points[row]
        sq_dist[row] = (px - ref[0]) ** 2 + (py - ref[1]) ** 2 + (pz - ref[2]) ** 2

    return sq_dist


def system(x, b):
    """Residuals: squared distances from ``x`` to the reference points minus ``b`` squared."""
    residual = func(x) - b ** 2
    return residual

# Least-squares fit of the three model coordinates, starting from (0.6, 2.0, -0.7).
# `args` is passed as a real one-element tuple: `(value)` without the trailing comma is
# just `value`, not a tuple.
x = scipy.optimize.leastsq(system, np.asarray((0.6, 2.0, -0.7)), args=(euclidean_dist_recon,), full_output=True)[0]

print(x)

## Testing the gradient calculation based on scalarized model

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_trisurf(verts[:, 0], verts[:, 1], verts[:, 2], triangles = faces.astype(np.int32), edgecolor=[[0,0,0]], linewidth=1.0, alpha=0.0, shade=False)

In [None]:
input_shape.requires_grad = True
grad_E = E_phi_grad(V, k, k_matrix_sum, sigma, z_i, input_shape.double(), L, sigma_ort, first_cplx, len(z_i)).detach().numpy()

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter(grad_E[:, 0], grad_E[:, 1], grad_E[:, 2])

plt.show()

In [None]:
plt.close("all")

In [None]:
# import matplotlib.pyplot as plt
# %matplotlib notebook        
# fig = plt.figure()
# ax1 = fig.add_subplot(1, 1, 1, projection='3d')
# plot = (it_shape).detach().numpy()
# ax1.scatter(plot[:, 0], plot[:, 1], plot[:, 2])

# Computing metrics for the black-box mixed dataset

In [None]:
def compute_metrics(prediction, target):
    """Compute per-run precision, recall, IoU and Dice score.

    Args:
        prediction: sequence of NumPy arrays, one predicted volume per run.
        target: sequence of NumPy arrays indexed like ``prediction``; entry
            ``i`` is compared with ``prediction[i]``.

    Returns:
        Tuple ``(precisions, recalls, ious, dice_scores)`` of NumPy arrays,
        each of length ``len(prediction)``.

    A ratio whose denominator is exactly zero (for example the precision of an
    all-zero prediction) is reported as ``0.0`` instead of NaN, so aggregate
    statistics over the runs stay finite. The Dice score keeps its epsilon in
    the denominator.
    """
    epsilon = 1e-9

    def _ratio(numerator, denominator):
        # 0/0 would give NaN; report 0.0 for an undefined ratio instead.
        if denominator == 0:
            return 0.0
        return numerator / denominator

    num_runs = len(prediction)

    precisions = np.zeros(num_runs)
    recalls = np.zeros(num_runs)
    ious = np.zeros(num_runs)
    dice_scores = np.zeros(num_runs)

    for i, pred_array in enumerate(prediction):
        pred = torch.from_numpy(pred_array)
        targ = torch.from_numpy(target[i])

        # Complements of prediction and target.
        p1 = 1 - pred
        g1 = 1 - targ

        tp = (targ * pred).sum()  # predicted and in target
        fp = (pred * g1).sum()    # predicted but not in target
        fn = (p1 * targ).sum()    # in target but not predicted

        precisions[i] = _ratio(tp, tp + fp)
        recalls[i] = _ratio(tp, tp + fn)
        ious[i] = _ratio(tp, tp + fp + fn)
        dice_scores[i] = (2 * tp) / (2 * tp + fp + fn + epsilon)

    return precisions, recalls, ious, dice_scores

In [None]:
# predictions = []
# targets = []
# 
# for i in range(len(subjects)):
#     opt_params = dict(num_iter=11, err_bound=0, gamma=1e-2, steps=1e-1)
#     cmf_params = dict(par_lambda=1.5, par_nu=0.7, c_zero=0.1, c_one=0.35, b_zero=1e-1, b_one=1e1,
#                   z_i=z_i, sigma_inv=sigma_inv, L=L, V=V, sigma_ort=sigma_ort, sigma=sigma, first_cplx=first_cplx, min_shape_face_count=min_shape_face_count, mean_shape=mean_shape, mean_shape_face=mean_shape_face,k_matrix_sum=k_matrix_sum, k_matrix=k_matrix, kernel=k)
#     lam, err_iter, num_iter, lam_shape_prior = segment_left_ventricle(a_volume=torch.from_numpy(subjects[i]['spect']), a_opt_params=opt_params, a_algo_params=cmf_params)
#     predictions.append(lam_shape_prior)
#     targets.append(subjects[i]['left_ventricle'])

In [None]:
# compute_metrics(predictions, targets)

In [None]:
# print(i)

In [None]:
# print(type(predictions), type(targets))

## Computing left ventricles on MPH images

In [None]:
print(len(subjects_bela)) # 0 - 2 is MPH images

In [None]:
mph_vol_0 = subjects_bela[15]['spect']
mph_lab_0 = subjects_bela[15]['left_ventricle']
mph_vol_1 = subjects_bela[16]['spect']
mph_lab_1 = subjects_bela[16]['left_ventricle']
mph_vol_2 = subjects_bela[17]['spect']
mph_lab_2 = subjects_bela[17]['left_ventricle']
mph_vol_3 = subjects_bela[18]['spect']
mph_lab_3 = subjects_bela[18]['left_ventricle']
mph_vol_4 = subjects_bela[19]['spect']
mph_lab_4 = subjects_bela[19]['left_ventricle']
mph_vol_5 = subjects_bela[20]['spect']
mph_lab_5 = subjects_bela[20]['left_ventricle']
mph_vol_6 = subjects_bela[21]['spect']
mph_lab_6 = subjects_bela[21]['left_ventricle']
mph_vol_7 = subjects_bela[22]['spect']
mph_lab_7 = subjects_bela[22]['left_ventricle']
mph_vol_8 = subjects_bela[23]['spect']
mph_lab_8 = subjects_bela[23]['left_ventricle']
mph_vol_9 = subjects_bela[24]['spect']
mph_lab_9 = subjects_bela[24]['left_ventricle']

In [None]:
# check sizes
print(mph_vol_0.shape, mph_vol_1.shape, mph_vol_2.shape)

In [None]:
# Segment subjects 15..24 of `subjects_bela` and collect, per subject:
#   targets_mph             - the 'left_ventricle' label volume,
#   predictions_mph_myocard - the binary mask derived from the shape-prior output,
#   predictions_mph         - the `lam` output restricted to that mask.
targets_mph = []
predictions_mph = []
predictions_mph_myocard = []

for i in range(15, 25):
    lv_volume = subjects_bela[i]['spect']
    lv_lab = subjects_bela[i]['left_ventricle']
    
    # Optimiser settings and model parameters for this dataset; the shape-prior
    # quantities (z_i, V, L, ...) come from the earlier nonlinear_shape_prior cell.
    opt_params = dict(num_iter=10, err_bound=0, gamma=1e-2, steps=1e-1)
    cmf_params = dict(par_lambda=1.5, par_nu=0.7, c_zero=0.2, c_one=0.21, b_zero=1e-1, b_one=1e1,
                      z_i=z_i, sigma_inv=sigma_inv, L=L, V=V, sigma_ort=sigma_ort, sigma=sigma, first_cplx=first_cplx, min_shape_face_count=min_shape_face_count, mean_shape=mean_shape, mean_shape_face=mean_shape_face,k_matrix_sum=k_matrix_sum, k_matrix=k_matrix, kernel=k)
    lam, err_iter, num_iter, lam_shape_prior = segment_left_ventricle(a_volume=torch.from_numpy(lv_volume), a_opt_params=opt_params, a_algo_params=cmf_params)
    
    targets_mph.append(lv_lab)
    
    # Flood-fill a copy of the shape-prior labels from voxel [0, 0, 0] with the value 2.
    # Voxels whose value is still <= 1 afterwards form the binary mask
    # (presumably everything not reachable from that corner — confirm).
    fill_value = 2
    label_prior = cp.copy(lam_shape_prior)
    filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)
    filled_myocard = np.where( filled_myocard <= 1, 1, 0)
    # Keep `lam` inside the mask and zero elsewhere.
    pred_myocard = np.where( filled_myocard == 1, lam, 0)
    
    predictions_mph_myocard.append(filled_myocard)
    predictions_mph.append(pred_myocard)

In [None]:
# Score the binary MPH masks against the ground-truth labels and print the mean and
# standard deviation of every metric over the runs.
# NOTE(review): the precision statistics skip the first run (`precisions[1:]`) while the
# other metrics use all runs — presumably because that entry is NaN; confirm.
precisions, recalls, ious, dice_scores = compute_metrics(predictions_mph_myocard, targets_mph)
print(precisions[1:].mean(), recalls.mean(), ious.mean(), dice_scores.mean())
print(precisions[1:].std(), recalls.std(), ious.std(), dice_scores.std()) 

In [None]:
# inspection figures in the paper
fill_value = 2  # Fill starting from [0, 0, 0] with the value 2
label_prior = cp.copy(lam_shape_prior)
filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)

import matplotlib.pyplot as plt

%matplotlib notebook
slice = 10

fig = plt.figure()
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(lam[slice, :, :])
ax1 = fig.add_subplot(1, 3, 2)
ax1.imshow(lam_shape_prior[slice, :, :])
ax1 = fig.add_subplot(1, 3, 3)
ax1.imshow(lv_volume[slice, :, :])

In [None]:
# fig = plt.figure()
# fig.set_size_inches((1, 1))
# ax = plt.Axes(fig, [0., 0., 1., 1.])
# ax.set_axis_off()
# fig.add_axes(ax)
# patient_3_mph_rest_8min, (TRA 32, 55:85, 55:85), (VLA 20:50, 73, 55:85), (SA 20:50, 55:85, 70)
# patient_2_mph_stress_8_min, (TRA 20, 45:75, 50:80), (VLA 5:35, 60, 50:80), (SA 5:35, 45:75, 67)
# patient_1_mph_stress_fp, (TRA ), (VLA ), (SA )

# plt.imshow(lv_volume[5:35, 45:75, 67], aspect='equal')
# plt.savefig("/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/images/" + "patient_2_mph_stress_8_min_dat_sa" + ".png", bbox_inches='tight', pad_inches=0)
# plt.close()

In [None]:
plt.close("all")

In [None]:
# compute segmented cardiac volume parameters
rows, cols, height = filled_myocard.shape
verts, faces = mcubes.marching_cubes(filled_myocard, 0.0)
v_decimate, f_decimate, v_correspondence, f_correspondence = pcu.decimate_triangle_mesh(verts, faces.astype(np.int32),
                                                                                        min_shape_face_count)
input_shape = torch.from_numpy(v_decimate / cols)

In [None]:
sigma = 5 * 1e0
pre_z = recon_preimg(V, k, sigma, z_i, input_shape, first_cplx, len(z_i)).detach().numpy()

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.scatter(pre_z[:, 0], pre_z[:, 1], pre_z[:, 2])

plt.show()

In [None]:
# Run recon_model_params on every MPH prediction whose maximum is positive.
mph_params = []

for i, lv_pred in enumerate(predictions_mph):
    # predictions without any positive value are skipped
    if not lv_pred.max() > 0:
        continue

    par = recon_model_params(lv_pred)
    mph_params.append(par)

In [None]:
print(mph_params)
# Write the string form of every parameter array to the measurements file. The context
# manager closes the file even if one of the writes raises.
with open('/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/measurements/recon_params_mph.txt', 'w') as file:
    for param in mph_params:
        file.write(np.array2string(param))

## Computing left ventricles for CardioC images

In [None]:
cardioc_vol_0 = subjects_bela[0]['spect']
cardioc_lab_0 = subjects_bela[0]['left_ventricle']
cardioc_vol_1 = subjects_bela[1]['spect']
cardioc_lab_1 = subjects_bela[1]['left_ventricle']
cardioc_vol_2 = subjects_bela[2]['spect']
cardioc_lab_2 = subjects_bela[2]['left_ventricle']
cardioc_vol_3 = subjects_bela[3]['spect']
cardioc_lab_3 = subjects_bela[3]['left_ventricle']
cardioc_vol_4 = subjects_bela[4]['spect']
cardioc_lab_4 = subjects_bela[4]['left_ventricle']
cardioc_vol_5 = subjects_bela[-9]['spect']
cardioc_lab_5 = subjects_bela[-9]['left_ventricle']
cardioc_vol_6 = subjects_bela[-4]['spect']
cardioc_lab_6 = subjects_bela[-4]['left_ventricle']
cardioc_vol_7 = subjects_bela[-3]['spect']
cardioc_lab_7 = subjects_bela[-3]['left_ventricle']
cardioc_vol_8 = subjects_bela[-2]['spect']
cardioc_lab_8 = subjects_bela[-2]['left_ventricle']
cardioc_vol_9 = subjects_bela[-1]['spect']
cardioc_lab_9 = subjects_bela[-1]['left_ventricle']

lv_volume = cardioc_vol_0

cardioc_labs = []
for i in [0, 1, 2, 3, 4, -9, -4, -3, -2, -1]:
    cardioc_labs.append(subjects_bela[i]['left_ventricle'])

In [None]:
# Segment the listed `subjects_bela` entries and collect, per subject:
#   targets_cardioc             - the 'left_ventricle' label volume,
#   predictions_cardioc_myocard - the binary mask derived from the shape-prior output,
#   predictions_cardioc         - the `lam` output restricted to that mask.
targets_cardioc = []
predictions_cardioc = []
predictions_cardioc_myocard = []

for i in [0, 1, 2, 3, 4, -9, -4, -3, -2, -1]:
    lv_volume = subjects_bela[i]['spect']
    lv_lab = subjects_bela[i]['left_ventricle']
    
    # Optimiser settings and model parameters for this dataset; the shape-prior
    # quantities (z_i, V, L, ...) come from the earlier nonlinear_shape_prior cell.
    opt_params = dict(num_iter=14, err_bound=0, gamma=1e-2, steps=1e-1)
    cmf_params = dict(par_lambda=1.0, par_nu=0.7, c_zero=0.2, c_one=0.25, b_zero=1e-1, b_one=1e1,
                      z_i=z_i, sigma_inv=sigma_inv, L=L, V=V, sigma_ort=sigma_ort, sigma=sigma, first_cplx=first_cplx, min_shape_face_count=min_shape_face_count, mean_shape=mean_shape, mean_shape_face=mean_shape_face,k_matrix_sum=k_matrix_sum, k_matrix=k_matrix, kernel=k)
    lam, err_iter, num_iter, lam_shape_prior = segment_left_ventricle(a_volume=torch.from_numpy(lv_volume), a_opt_params=opt_params, a_algo_params=cmf_params)
    
    targets_cardioc.append(lv_lab)
    
    # Flood-fill a copy of the shape-prior labels from voxel [0, 0, 0] with the value 2.
    # Voxels whose value is still <= 1 afterwards form the binary mask
    # (presumably everything not reachable from that corner — confirm).
    fill_value = 2
    label_prior = cp.copy(lam_shape_prior)
    filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)
    filled_myocard = np.where( filled_myocard <= 1, 1, 0)
    # Keep `lam` inside the mask and zero elsewhere.
    pred_myocard = np.where( filled_myocard == 1, lam, 0)
    
    predictions_cardioc_myocard.append(filled_myocard)
    predictions_cardioc.append(pred_myocard)

In [None]:
# Score the CardioC predictions against the ground-truth labels and print the mean and
# standard deviation of every metric over the runs.
# NOTE(review): this scores `predictions_cardioc` (the `lam` values inside the mask), whereas
# the MPH evaluation scores the binary masks in `predictions_mph_myocard` — confirm which
# prediction list is intended here.
precisions, recalls, ious, dice_scores = compute_metrics(predictions_cardioc, targets_cardioc)
print(precisions.mean(), recalls.mean(), ious.mean(), dice_scores.mean())
print(precisions.std(), recalls.std(), ious.std(), dice_scores.std())

In [None]:
# Inspection figures for the paper: run a 3-D flood fill on a copy of the shape-prior labels.
fill_value = 2  # value written by the flood fill, which is seeded at voxel [20, 62, 72]
label_prior = cp.copy(lam_shape_prior)  # copy made with copy.copy before calling the fill
filled_myocard = pcu.flood_fill_3d(label_prior, [20, 62, 72], fill_value)

import matplotlib.pyplot as plt

%matplotlib notebook
slice = 20

fig = plt.figure()
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(lam[slice, :, :])
ax1 = fig.add_subplot(1, 3, 2)
ax1.imshow(lam_shape_prior[slice, :, :])
ax1 = fig.add_subplot(1, 3, 3)
ax1.imshow(lv_volume[slice, :, :])

In [None]:
print(V.max())

In [None]:
# Run recon_model_params on every CardioC prediction whose maximum is positive.
cardioc_params = []

for i, lv_pred in enumerate(predictions_cardioc):
    # predictions without any positive value are skipped
    if not lv_pred.max() > 0:
        continue

    par = recon_model_params(lv_pred)
    cardioc_params.append(par)
print(cardioc_params)

In [None]:
# Write the string form of every parameter array to the measurements file. The context
# manager closes the file even if one of the writes raises.
with open('/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/measurements/recon_params_cardioc.txt',
          'w') as file:
    for param in cardioc_params:
        file.write(np.array2string(param))

## Computing left ventricles for trio parallel images

In [None]:
par_vol_0 = subjects_bela[-5]['spect']
par_lab_0 = subjects_bela[-5]['left_ventricle']
par_vol_1 = subjects_bela[-6]['spect']
par_lab_1 = subjects_bela[-6]['left_ventricle']
par_vol_2 = subjects_bela[-7]['spect']
par_lab_2 = subjects_bela[-7]['left_ventricle']
par_vol_3 = subjects_bela[-8]['spect']
par_lab_3 = subjects_bela[-8]['left_ventricle']

lv_volume = par_vol_1

In [None]:
# Segment entries -8, -7 and -5 of `subjects_bela` and collect, per subject:
#   targets_par             - the 'left_ventricle' label volume,
#   predictions_par_myocard - the binary mask derived from the shape-prior output,
#   predictions_par         - the `lam` output restricted to that mask.
targets_par = []
predictions_par = []
predictions_par_myocard = []

for i in [-8, -7, -5]:
    lv_volume = subjects_bela[i]['spect']
    lv_lab = subjects_bela[i]['left_ventricle']

    # Optimiser settings and model parameters for this dataset; the shape-prior
    # quantities (z_i, V, L, ...) come from the earlier nonlinear_shape_prior cell.
    opt_params = dict(num_iter=14, err_bound=0, gamma=1e-2, steps=1e-1)
    cmf_params = dict(par_lambda=1.0, par_nu=0.7, c_zero=0.4, c_one=0.5, b_zero=1e-1, b_one=1e1,
                      z_i=z_i, sigma_inv=sigma_inv, L=L, V=V, sigma_ort=sigma_ort, sigma=sigma, first_cplx=first_cplx, min_shape_face_count=min_shape_face_count, mean_shape=mean_shape, mean_shape_face=mean_shape_face,k_matrix_sum=k_matrix_sum, k_matrix=k_matrix, kernel=k)
    lam, err_iter, num_iter, lam_shape_prior = segment_left_ventricle(a_volume=torch.from_numpy(lv_volume), a_opt_params=opt_params, a_algo_params=cmf_params)

    targets_par.append(lv_lab)

    # Flood-fill a copy of the shape-prior labels from voxel [0, 0, 0] with the value 2.
    # Voxels whose value is still <= 1 afterwards form the binary mask
    # (presumably everything not reachable from that corner — confirm).
    fill_value = 2
    label_prior = cp.copy(lam_shape_prior)
    filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)
    filled_myocard = np.where( filled_myocard <= 1, 1, 0)
    # Keep `lam` inside the mask and zero elsewhere.
    pred_myocard = np.where( filled_myocard == 1, lam, 0)

    predictions_par_myocard.append(filled_myocard)
    predictions_par.append(pred_myocard)

In [None]:
# Score the parallel-dataset predictions against the ground-truth labels collected in
# `targets_par`. Previously the second argument was `predictions_par_myocard`, i.e. one
# prediction was scored against another and the ground truth was never used.
precisions, recalls, ious, dice_scores = compute_metrics(predictions_par, targets_par)
print(precisions.mean(), recalls.mean(), ious.mean(), dice_scores.mean())
print(precisions.std(), recalls.std(), ious.std(), dice_scores.std())

In [None]:
# inspection figures in the paper
fill_value = 2  # Fill starting from [0, 0, 0] with the value 2
label_prior = cp.copy(lam_shape_prior)
filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)

import matplotlib.pyplot as plt

%matplotlib notebook
slice = 10

fig = plt.figure()
ax1 = fig.add_subplot(1, 4, 1)
ax1.imshow(lam[slice, :, :])
ax1 = fig.add_subplot(1, 4, 2)
ax1.imshow(lam_shape_prior[slice, :, :])
ax1 = fig.add_subplot(1, 4, 3)
ax1.imshow(subjects_bela[-8]['left_ventricle'][slice, :, :])
ax1 = fig.add_subplot(1, 4, 4)
ax1.imshow(lv_volume[slice, :, :])

In [None]:
# Reconstruct the model parameters of every non-empty parallel-geometry
# prediction (predictions whose maximum is not positive are skipped).
par_params = []

for i, lv_pred in enumerate(predictions_par):
    if not lv_pred.max() > 0:
        continue
    par = recon_model_params(lv_pred)
    par_params.append(par)
print(par_params)

In [None]:
# Write the reconstructed parameters to disk. The context manager closes the
# file even when one of the writes raises.
with open(
        '/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/measurements/recon_params_par.txt',
        'w') as file:
    for param in par_params:
        file.write(np.array2string(param))

## Computing left ventricles for CardioD images

In [None]:
# SPECT volume and left-ventricle label of the subjects at indices 4 ... 14
# of subjects_bela.
cardiod_vol_0, cardiod_lab_0 = subjects_bela[4]['spect'], subjects_bela[4]['left_ventricle']
cardiod_vol_1, cardiod_lab_1 = subjects_bela[5]['spect'], subjects_bela[5]['left_ventricle']
cardiod_vol_2, cardiod_lab_2 = subjects_bela[6]['spect'], subjects_bela[6]['left_ventricle']
cardiod_vol_3, cardiod_lab_3 = subjects_bela[7]['spect'], subjects_bela[7]['left_ventricle']
cardiod_vol_4, cardiod_lab_4 = subjects_bela[8]['spect'], subjects_bela[8]['left_ventricle']
cardiod_vol_5, cardiod_lab_5 = subjects_bela[9]['spect'], subjects_bela[9]['left_ventricle']
cardiod_vol_6, cardiod_lab_6 = subjects_bela[10]['spect'], subjects_bela[10]['left_ventricle']
cardiod_vol_7, cardiod_lab_7 = subjects_bela[11]['spect'], subjects_bela[11]['left_ventricle']
cardiod_vol_8, cardiod_lab_8 = subjects_bela[12]['spect'], subjects_bela[12]['left_ventricle']
cardiod_vol_9, cardiod_lab_9 = subjects_bela[13]['spect'], subjects_bela[13]['left_ventricle']
cardiod_vol_10, cardiod_lab_10 = subjects_bela[14]['spect'], subjects_bela[14]['left_ventricle']

# The same labels once more, gathered into a list.
cardiod_labs = []
for i in range(4, 15):
    cardiod_labs += [subjects_bela[i]['left_ventricle']]

# Select the first volume as the current working volume.
lv_volume = cardiod_vol_0

In [None]:
# Segment the subjects at indices 5 ... 14 of subjects_bela and collect, per
# subject, the ground-truth label, the masked segmentation and the binary mask.
# NOTE(review): cardiod_labs is built from range(4, 15), this loop starts at 5,
# so index 4 is not segmented here - confirm that this is intended.
targets_cardiod, predictions_cardiod, predictions_cardiod_myocard = [], [], []

for i in range(5, 15):
    lv_volume = subjects_bela[i]['spect']
    lv_lab = subjects_bela[i]['left_ventricle']

    opt_params = {'num_iter': 10, 'err_bound': 0, 'gamma': 1e-2, 'steps': 1e-1}
    cmf_params = {
        'par_lambda': 1.0, 'par_nu': 0.7,
        'c_zero': 0.4, 'c_one': 0.5,
        'b_zero': 1e-1, 'b_one': 1e1,
        'z_i': z_i, 'sigma_inv': sigma_inv, 'L': L, 'V': V,
        'sigma_ort': sigma_ort, 'sigma': sigma, 'first_cplx': first_cplx,
        'min_shape_face_count': min_shape_face_count,
        'mean_shape': mean_shape, 'mean_shape_face': mean_shape_face,
        'k_matrix_sum': k_matrix_sum, 'k_matrix': k_matrix, 'kernel': k,
    }
    lam, err_iter, num_iter, lam_shape_prior = segment_left_ventricle(
        a_volume=torch.from_numpy(lv_volume),
        a_opt_params=opt_params,
        a_algo_params=cmf_params,
    )

    targets_cardiod.append(lv_lab)

    # Flood-fill a copy of the shape prior from the corner voxel with the
    # value 2; voxels that still hold a value <= 1 afterwards form the mask.
    fill_value = 2
    label_prior = cp.copy(lam_shape_prior)
    filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)
    filled_myocard = np.where(filled_myocard <= 1, 1, 0)

    # Keep the segmentation values only inside the mask.
    pred_myocard = np.where(filled_myocard == 1, lam, 0)

    predictions_cardiod_myocard.append(filled_myocard)
    predictions_cardiod.append(pred_myocard)

In [None]:
# Score the CardioD predictions against the ground-truth labels collected in
# targets_cardiod, then print mean and standard deviation of every metric.
precisions, recalls, ious, dice_scores = compute_metrics(predictions_cardiod, targets_cardiod)
print(precisions.mean(), recalls.mean(), ious.mean(), dice_scores.mean())
print(precisions.std(), recalls.std(), ious.std(), dice_scores.std())

In [None]:
# inspection figures in the paper
# Rebuild the binary mask from the shape prior of the subject processed last
# by the segmentation loop.
fill_value = 2  # Fill starting from [0, 0, 0] with the value 2
label_prior = cp.copy(lam_shape_prior)
filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)
# Voxels that still hold a value <= 1 after the fill become the mask.
filled_myocard = np.where( filled_myocard <= 1, 1, 0)

import matplotlib.pyplot as plt

%matplotlib notebook
# NOTE: this global shadows the builtin `slice` for the rest of the session.
slice = 10

# One row, four panels: segmentation, mask, input volume, ground-truth label.
fig = plt.figure()
ax1 = fig.add_subplot(1, 4, 1)
ax1.imshow(lam[slice, :, :])
ax1 = fig.add_subplot(1, 4, 2)
ax1.imshow(filled_myocard[slice, :, :])
ax1 = fig.add_subplot(1, 4, 3)
ax1.imshow(lv_volume[slice, :, :])
ax1 = fig.add_subplot(1, 4, 4)
ax1.imshow(lv_lab[slice, :, :])

In [None]:
# Reconstruct the model parameters of every non-empty CardioD prediction
# (predictions whose maximum is not positive are skipped).
cardiod_params = []

for i, lv_pred in enumerate(predictions_cardiod):
    if not lv_pred.max() > 0:
        continue
    par = recon_model_params(lv_pred)
    cardiod_params.append(par)
print(cardiod_params)

In [None]:
# Write the reconstructed parameters to disk. The context manager closes the
# file even when one of the writes raises.
with open(
        '/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/measurements/recon_params_cardiod.txt',
        'w') as file:
    for param in cardiod_params:
        file.write(np.array2string(param))

## Testing the algorithm against noise on x-cat phantom

In [None]:
# Loading the phantom from the remote server
# http://localhost:8000/simulated/segmentation/xcat_data/male_size128_beating_mask1/xcat_phantom_act_av.bin

# initialize data fetching from remote, configuration is in data/remote.yml
data_loaded = False
url, datasets = load_remote_data()

# read all filenames from the url
from bs4 import BeautifulSoup
import requests
# A timeout keeps the notebook from hanging on an unresponsive server, and
# raise_for_status() stops an HTTP error page from being parsed as a listing.
page = requests.get(url + '/simulated/' + 'segmentation/' + 'xcat_data/' + 'male_size128_beating_mask1/',
                    timeout=30)
page.raise_for_status()
soup = BeautifulSoup(page.content, 'html.parser')
# Collect the href of every anchor found in the page.
phantom_names = []
for label_ref in soup.find_all('a'):
    phantom_names.append(label_ref.get('href'))


In [None]:
# Download every phantom volume named in phantom_names (the first and the
# last entry of the listing are skipped) and store it normalised.
subjects_phantom = []

# Shape of one raw phantom volume.
imageSize = (52, 128, 128)

for index in range(1, len(phantom_names) - 1):

    phantom_name = phantom_names[index]
    data_url = url + '/simulated/' + 'segmentation/' + 'xcat_data/' + 'male_size128_beating_mask1/' + phantom_name
    print(data_url)

    # fetch the data from remote
    data = fetch_data(data_url)

    # Interpret the payload as float32 voxels. np.frombuffer returns a
    # read-only array when the buffer is immutable (e.g. bytes), which an
    # in-place normalize_volume cannot write to, so take a writable copy.
    volume = np.frombuffer(data.getvalue(), dtype=np.float32).reshape(imageSize).copy()

    subject = {
        'spect' : normalize_volume(volume),
    }
    subjects_phantom.append(subject)
    print("Volume shape: ", volume.shape)

In [None]:
# inspect phantom volumes
import matplotlib.pyplot as plt
%matplotlib notebook

# Show one slice of the phantom that was loaded last.
phantom_vol = subjects_phantom[-1]['spect']
slice = 25

fig = plt.figure()
plt.imshow(phantom_vol[slice, :, :])

In [None]:
def normalize_volume(a_data):
    """Scale ``a_data`` in place so that its maximum becomes 1.

    The array is modified in place and also returned, so both
    ``normalize_volume(v)`` and ``v = normalize_volume(v)`` work.

    Args:
        a_data: NumPy array with a floating-point dtype (in-place division
            is used, so integer arrays are not supported).

    Returns:
        The same array object that was passed in.

    An empty array, an array whose maximum is already 1 and an array whose
    maximum is 0 are returned unchanged; dividing the latter would only
    produce NaN.
    """
    if a_data.size == 0:
        return a_data

    max_detect_count = np.max(a_data)
    if max_detect_count != 1.0 and max_detect_count != 0:
        # One vectorised in-place division replaces the per-slice loop.
        a_data /= max_detect_count

    return a_data

def add_noise(npimg, noise_mode):
    """Return copies of ``npimg`` with noise added at eight fixed intensities.

    ``npimg`` is first passed to ``normalize_volume`` (which works in place).
    For every intensity a fresh ``random_noise`` sample is drawn on an
    all-ones frame of the image's shape, multiplied by the intensity and
    added to the image.

    Args:
        npimg: image/volume array; modified in place by the normalisation.
        noise_mode: ``mode`` argument forwarded to ``random_noise``.

    Returns:
        Array of shape ``(8, *npimg.shape)``; entry ``i`` belongs to the
        ``i``-th intensity of ``[0, 1e-3, 1e-2, 1e-1, 0.5, 0.6, 0.8, 1]``.
    """
    noise_intensities = [0, 1e-3, 1e-2, 1e-1, 0.5, 0.6, 0.8, 1]

    normalize_volume(npimg)

    unit_frame = np.ones(npimg.shape)
    shifted_frames_noise = np.zeros([len(noise_intensities), *npimg.shape])

    for level, intensity in enumerate(noise_intensities):
        # A new noise sample is drawn for every intensity level.
        noise = random_noise(unit_frame, mode=noise_mode, clip=False)
        shifted_frames_noise[level] = npimg + intensity * noise

    return shifted_frames_noise

In [None]:
# Build the stack of noisy versions (one per noise intensity, Poisson mode)
# for every loaded phantom.
noisy_phantoms = []
for npimg in subjects_phantom:
    sfn = add_noise(npimg['spect'], 'poisson')
    noisy_phantoms.append(sfn)

print(len(noisy_phantoms))

In [None]:
# Segment the last phantom at each of the 8 noise intensities. The target is
# the support (voxels > 0) of the stored, noise-free phantom volume.
predictions, targets = [], []

for i in range(8):

    noise_intensity = i
    lv_volume = noisy_phantoms[-1][noise_intensity]
    lv_lab = np.where(subjects_phantom[-1]['spect'] > 0, 1, 0)

    opt_params = {'num_iter': 10, 'err_bound': 0, 'gamma': 1e-2, 'steps': 1e-1}
    cmf_params = {
        'par_lambda': 1.0, 'par_nu': 0.7,
        'c_zero': 0.1, 'c_one': 0.9,
        'b_zero': 1e-1, 'b_one': 1e1,
        'z_i': z_i, 'sigma_inv': sigma_inv, 'L': L, 'V': V,
        'sigma_ort': sigma_ort, 'sigma': sigma, 'first_cplx': first_cplx,
        'min_shape_face_count': min_shape_face_count,
        'mean_shape': mean_shape, 'mean_shape_face': mean_shape_face,
        'k_matrix_sum': k_matrix_sum, 'k_matrix': k_matrix, 'kernel': k,
    }
    lam, err_iter, num_iter, lam_shape_prior = segment_left_ventricle(
        a_volume=torch.from_numpy(lv_volume),
        a_opt_params=opt_params,
        a_algo_params=cmf_params,
    )

    targets.append(torch.from_numpy(lv_lab))

    # Flood-fill a copy of the shape prior from the corner voxel with the
    # value 2; voxels that still hold a value <= 1 afterwards form the mask,
    # which is stored as the prediction.
    fill_value = 2
    label_prior = cp.copy(lam_shape_prior)
    filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)
    filled_myocard = np.where(filled_myocard <= 1, 1, 0)

    predictions.append(filled_myocard)

In [None]:
import torch
from torcheval.metrics.functional import peak_signal_noise_ratio

# PSNR of every noisy version of the last phantom, measured against the
# stored noise-free volume; index i corresponds to noise intensity i.
psnr = torch.zeros(8)
for i in range(8):
    noise_intensity = i
    psnr[i] = peak_signal_noise_ratio(torch.from_numpy(noisy_phantoms[-1][noise_intensity]), torch.from_numpy(subjects_phantom[-1]['spect']))

# Segmentation quality at every noise intensity.
precisions, recalls, ious, dice_scores = compute_metrics(predictions, targets)
print(psnr)

In [None]:
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d, splrep, splev, PchipInterpolator

%matplotlib notebook
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False

fig = plt.figure()

psnr[0] = 65
x = psnr.numpy()
x_ = np.linspace(psnr.max(), psnr.min(), 500)

prec_interp = PchipInterpolator(np.flip(x), np.flip(precisions), extrapolate=True)
prec_plot = plt.plot(x_, prec_interp(x_), '-', label='Precision')
plt.plot(x, precisions, 'o', color=prec_plot[0].get_color())

recalls_interp = PchipInterpolator(np.flip(x), np.flip(recalls), extrapolate=True)
recall_plot = plt.plot(x_, recalls_interp(x_), '-', label='Recall')
plt.plot(x, recalls, 'o', color=recall_plot[0].get_color())

dices_interp = PchipInterpolator(np.flip(x), np.flip(ious), extrapolate=True)
iou_plot = plt.plot(x_, dices_interp(x_), '-', label='IoU')
plt.plot(x, ious, 'o', color=iou_plot[0].get_color())

dices_interp = PchipInterpolator(np.flip(x), np.flip(dice_scores), extrapolate=True)
dice_plot = plt.plot(x_, dices_interp(x_), '-', label='Dice')
plt.plot(x, dice_scores, 'o', color=dice_plot[0].get_color())

plt.legend()
plt.xlabel('PSNR')


plt.show()

plt.savefig("/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/images/" + "phantom_noise.pdf", bbox_inches='tight', pad_inches=0)
plt.close()

In [None]:
# inspection figures in the paper
# Uses the results of the last iteration of the phantom segmentation loop.
fill_value = 2  # Fill starting from [0, 0, 0] with the value 2
label_prior = cp.copy(lam_shape_prior)
filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)

import matplotlib.pyplot as plt

%matplotlib notebook
slice = 30

# One row, three panels: segmentation, shape prior, noisy input volume.
fig = plt.figure()
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(lam[slice, :, :])
ax1 = fig.add_subplot(1, 3, 2)
ax1.imshow(lam_shape_prior[slice, :, :])
ax1 = fig.add_subplot(1, 3, 3)
ax1.imshow(lv_volume[slice, :, :])

## Calculating reorientation of the left ventricles

In [None]:
# Triangulate the iso-surface (level 0) of the shape prior at index 5, then
# compute one normal and one area per triangle.
cardiac_verts, cardiac_faces, _, _ = measure.marching_cubes(shape_priors[5], 0)
normals = pcu.estimate_mesh_face_normals(cardiac_verts, cardiac_faces)
face_areas = pcu.mesh_face_areas(cardiac_verts, cardiac_faces)

print("Vertices: ", cardiac_verts.shape, "Faces: ", cardiac_faces.shape, " Normals: ", normals.shape, " Face areas: ", face_areas.shape)

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook
# 3-D view of the extracted mesh; alpha=0.0 is passed for the surface and the
# triangle edges are given a black edge colour.
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_trisurf(cardiac_verts[:, 0], cardiac_verts[:, 1], cardiac_verts[:, 2], triangles = cardiac_faces.astype(np.int32), edgecolor=[[0,0,0]], linewidth=1.0, alpha=0.0, shade=False)

In [None]:
# compute the reorientation based on https://rreusser.github.io/aligning-3d-scans/
# Scale every face normal by the area of its face (one row per face).
area_vectors = normals * face_areas[:, None]
print(area_vectors.shape)

# Second moments of the area vectors: sums over all faces of the products of
# their components. Only the six distinct entries are computed.
Axx = np.sum(area_vectors[:, 0] * area_vectors[:, 0])
Axy = np.sum(area_vectors[:, 0] * area_vectors[:, 1])
Axz = np.sum(area_vectors[:, 0] * area_vectors[:, 2])
Ayy = np.sum(area_vectors[:, 1] * area_vectors[:, 1])
Ayz = np.sum(area_vectors[:, 1] * area_vectors[:, 2])
Azz = np.sum(area_vectors[:, 2] * area_vectors[:, 2])

In [None]:
# Symmetric 3x3 matrix of the second moments of the area vectors. It is
# stored in full as a plain ndarray (np.matrix is no longer recommended by
# NumPy), so the printed matrix is the one actually being decomposed.
# np.linalg.eigh gives the same result for it as for the lower triangle alone.
A = np.array([[Axx, Axy, Axz],
              [Axy, Ayy, Ayz],
              [Axz, Ayz, Azz]])
print(A)

In [None]:
# Eigen-decomposition of the symmetric matrix A. By default eigh reads only
# the lower triangle of its argument, returns the eigenvalues in ascending
# order and stores the matching eigenvectors as the COLUMNS of eigvecs.
eigvals, eigvecs = np.linalg.eigh(A)

In [None]:
# np.linalg.eigh returns the eigenvectors as columns, sorted by ascending
# eigenvalue: column 2 is the eigenvector of the largest eigenvalue.
# Indexing eigvecs[2] would print a row, which is not an eigenvector.
print(eigvecs[:, 2])

## Reconstruct model parameters from segmented left ventricles

In [None]:
def recon_model_params(a_pred_myocard):
    """Estimate the left-ventricle model parameters of a segmented volume.

    Steps:
      1. Extract the iso-surface (level 0) of ``a_pred_myocard``, decimate it
         to ``min_shape_face_count`` faces and divide the vertices by the
         size of the volume's second axis.
      2. Map that point set through ``recon_preimg``.
      3. For all pairs of shape priors, tabulate the Euclidean distance
         between their parameter triples against their kernel value, and
         interpolate distance as a function of kernel value.
      4. Evaluate the kernel between the result of step 2 and every prior,
         turn it into an estimated parameter-space distance with the
         interpolant, and solve by least squares for the point whose squared
         distances to the priors' parameter triples match the squared
         estimates.

    Relies on notebook globals: ``V``, ``k``, ``z_i``, ``first_cplx``,
    ``min_shape_face_count``, ``shape_priors``, ``wall_thickness``,
    ``curvature``, ``sigmas`` and ``recon_preimg``.
    NOTE(review): ``k`` is assigned in more than one notebook cell with
    different definitions; make sure the intended kernel is the active one.

    Args:
        a_pred_myocard: 3-D volume of the segmented left ventricle.

    Returns:
        The solution array of ``scipy.optimize.leastsq``, i.e. the estimated
        triple in the order (wall thickness, curvature, sigma).
    """
    from scipy import interpolate
    import scipy.optimize

    _, cols, _ = a_pred_myocard.shape

    verts, faces = mcubes.marching_cubes(a_pred_myocard, 0.0)
    v_decimate, _, _, _ = pcu.decimate_triangle_mesh(verts, faces.astype(np.int32),
                                                     min_shape_face_count)
    input_shape = torch.from_numpy(v_decimate / cols)

    sigma = 5 * 1e0
    pre_z = recon_preimg(V, k, sigma, z_i, input_shape, first_cplx, len(z_i)).detach().numpy()

    # parameter limits, just for hints on initial values for the lstsq
    # wall_thickness = np.random.uniform(0.3, 1.0, num_prior)
    # rot_angles = np.random.uniform(0, 2 * np.pi, num_prior)
    # curvature = np.random.uniform(1.5, 3, num_prior)
    # sigmas = np.random.uniform(-0.5, -1, num_prior)
    recon_coords = torch.from_numpy(pre_z)

    # One parameter triple per shape prior.
    N = len(shape_priors)
    points = torch.zeros([N, 3])
    for i in range(N):
        points[i] = torch.tensor([wall_thickness[i], curvature[i], sigmas[i]])

    # Pairwise table: parameter-space distance vs. kernel value.
    euclidean_dist = torch.zeros(N * N)
    feature_dist = torch.zeros(N * N)
    for i in range(N):
        for j in range(N):
            euclidean_dist[i * N + j] = torch.cdist(points[i, None], points[j, None], p=2)
            feature_dist[i * N + j] = k(z_i[i], z_i[j], sigma)

    f = interpolate.interp1d(feature_dist, euclidean_dist, fill_value='extrapolate')

    # Now estimate the distance of the reconstructed shape to every prior
    # from its kernel value.
    feature_dist_recon = torch.zeros(N)
    euclidean_dist_recon = torch.zeros(N)
    for i in range(N):
        feature_dist_recon[i] = k(recon_coords.double(), z_i[i].double(), sigma)
        euclidean_dist_recon[i] = torch.from_numpy(f(feature_dist_recon[i]))

    def func(par):
        # Squared distance from the candidate triple to every prior's triple.
        x1, x2, x3 = par
        eqs = torch.zeros(N)
        for i in range(N):
            a = points[i, 0]
            b = points[i, 1]
            c = points[i, 2]

            eqs[i] = (x1 - a) ** 2 + (x2 - b) ** 2 + (x3 - c) ** 2

        return eqs

    def system(x, b):
        # Residuals: squared distances minus squared estimated distances.
        return (func(x) - b ** 2)

    # args must be a tuple; the starting point lies inside the parameter
    # ranges listed in the hint comments above.
    x = scipy.optimize.leastsq(system, np.asarray((0.6, 2.0, -0.7)), args=(euclidean_dist_recon,),
                               full_output=True)[0]
    return x

## UMAP estimation of parameter manifold

In [None]:
# load umap here and estimate parameter manifolds, compare MPH, CardioC, CardioD, Parallel
# use UMAP for the manifold of parameters in models of segmented left ventricles, different geometry left ventricles UMAP
# finished labeling, will need to run umap on the reconstructed left ventricular models

In [None]:
# Hard-coded parameter triples per collimation geometry, used as input of the
# UMAP cells below (presumably copied from earlier recon_model_params runs -
# TODO confirm the source of these numbers).
# NOTE: this overwrites any cardiod_params computed earlier in the session.
cardioc_params = [[ 1.06569109, 2.57670673, -0.4718422 ], [1.03302996, 2.58094824, -0.45705688],  [1.03656181,  2.58443841, -0.44857052], [1.0270633, 2.58328001, -0.45851678], [ 1.03801494, 2.59540623, -0.47939053]]

cardiod_params = [[0.5244653, 2.11324275, -0.68215769], [0.51049772, 2.12702749, -0.68606219], [0.27010819, 2.27369073, -0.7082898], [ 0.51374591, 2.10780751, -0.68423935], [0.52639149, 2.11545111, -0.68032458]]

mph_params = [[1.03845203, 2.58694343, -0.46124234], [1.01761388, 2.58836604, -0.46353688], [1.02260157, 2.5775182, -0.44934651], [1.02061359, 2.57166759, -0.45235483], [1.042973, 2.56663061, -0.46675882], [1.03973152, 2.56906939, -0.46781483], [1.03172376, 2.59245586, -0.46192331], [0.99187081, 2.58873384, -0.43912346]]

parallel_params = [[1.02315087, 2.58171537, -0.46215367], [1.2315087, 2.28171537, -0.36215367], [1.2315087, 2.1812132537, -0.546215367], [1.22315087, 2.68171537, -0.4315367], [1.12315087, 2.68171537, -0.56215367]]

In [None]:
import umap

# Embeddings are stored as <name>_embedding[row][col]: row = index of the
# min_dist value (0, 0.25, 0.5, 0.75), col = index of n_neighbors (2 ... 5).
# Every parameter set is embedded on its own with a fresh UMAP instance.
cardioc_embedding, cardiod_embedding, mph_embedding, parallel_embedding = [], [], [], []

i = 0
for mdist in np.arange(0, 1, 0.25):
    # Open a new row in each of the four result grids.
    cardioc_embedding.append([])
    cardiod_embedding.append([])
    mph_embedding.append([])
    parallel_embedding.append([])

    for neighbor in np.arange(2, 6, 1):
        cardioc_embedding[-1].append(
            umap.UMAP(n_neighbors=neighbor, min_dist=mdist).fit_transform(cardioc_params))
        cardiod_embedding[-1].append(
            umap.UMAP(n_neighbors=neighbor, min_dist=mdist).fit_transform(cardiod_params))
        mph_embedding[-1].append(
            umap.UMAP(n_neighbors=neighbor, min_dist=mdist).fit_transform(mph_params))
        parallel_embedding[-1].append(
            umap.UMAP(n_neighbors=neighbor, min_dist=mdist).fit_transform(parallel_params))

    # Progress: number of finished min_dist rows.
    i += 1
    print(i)

In [None]:
# Number of embeddings stored in the first row of the grid.
print(len(cardioc_embedding[0]))

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

plt.rcParams['xtick.bottom'] = False
plt.rcParams['xtick.labelbottom'] = False
plt.rcParams['ytick.left'] = False
plt.rcParams['ytick.labelleft'] = False

fig, axs = plt.subplots(4, 4)
fig.suptitle('UMAP of reconstructed left ventricle model parameters', fontsize=14)

for i in range(4):
    for j in range(4):
        axs[i, j].set_xticks([], minor=True)
        axs[i, j].scatter(cardioc_embedding[i][j][:, 0], cardioc_embedding[i][j][:, 1])
        axs[i, j].scatter(cardiod_embedding[i][j][:, 0], cardiod_embedding[i][j][:, 1])
        axs[i, j].scatter(mph_embedding[i][j][:, 0], mph_embedding[i][j][:, 1])
        axs[i, j].scatter(parallel_embedding[i][j][:, 0], parallel_embedding[i][j][:, 1])

axs[-1, 0].set_ylabel('dist = ' + str(4 * 0.25))
axs[-1, 0].set_xlabel('#neighbors = ' + str(2))

for i in range(3):
    axs[i, 0].set_ylabel(str(i * 0.25))

for i in range(1, 4):
    axs[-1, i].set_xlabel(str(i * 1 + 2))

fig.tight_layout()
fig.subplots_adjust(top=0.88)
plt.savefig("/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/images/" + "recon_model_umap.pdf", bbox_inches='tight', pad_inches=0)
plt.close()

## Kolmogorov-Smirnov (Peacock) test between segmented labels

In [None]:
from multidimensionalks import test as ktest

In [None]:
# Try the multidimensional KS test on 1000 uniformly random 3-D points.
# NOTE(review): the meaning of the cdf argument is not visible here - check
# the multidimensionalks documentation before relying on the result.
simulation = np.random.rand(1000, 3)
ktest(simulation, cdf=np.array([[1,2,2]]), use_avx=0, binomial_significance=True)

## Wilcoxon signed-rank test

In [None]:
# Per-geometry lists of segmented left ventricles, in the order
# MPH, parallel, CardioD, CardioC.
combined_segmented_lvs = [
    predictions_mph_myocard,
    predictions_par_myocard,
    cardiod_preds,
    cardioc_preds,
]

num_mph, num_par = len(predictions_mph_myocard), len(predictions_par_myocard)
num_cardiod, num_cardioc = len(cardiod_preds), len(cardioc_preds)
num_combined = len(combined_segmented_lvs)

# One square matrix of pairwise distances per geometry.
dist_mx_mph = np.zeros((num_mph, num_mph))
dist_mx_par = np.zeros((num_par, num_par))
dist_mx_cardioc = np.zeros((num_cardioc, num_cardioc))
dist_mx_cardiod = np.zeros((num_cardiod, num_cardiod))

from geomloss import SamplesLoss

# Sinkhorn loss between point clouds and an exponential kernel built on it.
# NOTE: this rebinds the notebook globals `sigma` and `k`; another cell
# defines `k` with a different formula, so cell execution order matters.
eps = 5 * 1e-3
loss_unbalanced = SamplesLoss(loss='sinkhorn', p=2, blur=eps, scaling=0.95)
sigma = 5 * 1e0
k = lambda x, y, sigma : torch.exp(-sigma * loss_unbalanced(x, y))

def compute_dist_mx(samplesize, samples, dist_mx):
    """Fill ``dist_mx`` with pairwise ``loss_unbalanced`` values of surfaces.

    For each of the first ``samplesize`` volumes in ``samples`` whose maximum
    is positive, the iso-surface at level 0 is extracted and decimated to 100
    faces. ``dist_mx[i, j]`` is then set to the loss between the vertex sets
    of samples ``i`` and ``j``. Entries involving a sample without positive
    values are left untouched.

    The surface of each sample is extracted once up front rather than once
    per pair; this relies on the extraction and decimation being
    deterministic.

    Args:
        samplesize: number of leading entries of ``samples`` to compare.
        samples: sequence of 3-D volumes.
        dist_mx: array of shape at least ``(samplesize, samplesize)``;
            written in place.
    """
    def _surface_points(volume):
        # Vertex set of the decimated iso-surface as a torch tensor.
        # NOTE(review): dividing the volume by `cols` rescales intensities,
        # not vertex coordinates - confirm that this is what is intended.
        _, cols, _ = volume.shape
        verts, tris, _, _ = measure.marching_cubes(volume / cols, 0.0)
        v_decimate, _, _, _ = pcu.decimate_triangle_mesh(verts, tris.astype(np.int32), 100)
        return torch.from_numpy(v_decimate.copy())

    clouds = [
        _surface_points(samples[idx]) if samples[idx].max() > 0 else None
        for idx in range(samplesize)
    ]

    for i, cloud_a in enumerate(clouds):
        if cloud_a is None:
            continue
        for j, cloud_b in enumerate(clouds):
            if cloud_b is None:
                continue
            dist_mx[i, j] = loss_unbalanced(cloud_a, cloud_b)
            
# Fill every geometry's distance matrix from that geometry's own samples.
# The CardioD and CardioC output matrices were swapped before, which
# mislabels the results and fails when the two sample counts differ.
compute_dist_mx(num_mph, predictions_mph_myocard, dist_mx_mph)
compute_dist_mx(num_par, predictions_par_myocard, dist_mx_par)
compute_dist_mx(num_cardiod, cardiod_preds, dist_mx_cardiod)
compute_dist_mx(num_cardioc, cardioc_preds, dist_mx_cardioc)
            

In [None]:
# Flatten the upper triangle of every distance matrix into a 1-D sample.
# np.triu zeroes the entries below the diagonal but keeps the array shape, so
# the flattened vectors still contain those zeros and the diagonal entries.
distribution_mph = np.triu(dist_mx_mph).flatten()
distribution_par = np.triu(dist_mx_par).flatten()
distribution_cardioc = np.triu(dist_mx_cardioc).flatten()
distribution_cardiod = np.triu(dist_mx_cardiod).flatten()

In [None]:
from scipy.stats import wilcoxon

# Wilcoxon signed-rank test on a single sample (the CardioD distances). With
# one argument, scipy treats the values themselves as the differences.
res = wilcoxon(distribution_cardiod)
res
# print(res)
# print(distribution_mph)

In [None]:
import pandas as pd
# Long-format table for seaborn: one row per distance value, all tagged with
# the name 'mph' (only the MPH distribution is included).
data = {'name' : 'mph' ,'distribution' : distribution_mph}
distribution_dataframe = pd.DataFrame(data)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib notebook


# Palettes for the areas and the datapoints
# NOTE(review): the dataframe plotted here only contains the name 'mph'; the
# 'par' / 'Sqa_zz' palette keys and the three tick labels set further down
# have no matching data in this cell - confirm what the plot should show.
# Light colors for the dots
swarmplot_palette = {'mph':'#8f96bf', 'par':'#ebb0e5', 'Sqa_zz':'#9feed3'}

# Dark colors for the violin
violin_palette = {'mph':'#333c70', 'par':'#90367c', 'Sqa_zz':'#34906c'}

# create figure and seaborn context
sns.set_context('notebook', font_scale=1.2)
fig, ax = plt.subplots(figsize=(9,5))

# Plot the violin
ax = sns.violinplot(y="distribution", 
                    x="name", 
                    data=distribution_dataframe,
                    palette=violin_palette,
                    density_norm='count',
                    inner=None
              )

# Plot the swarmplot on top 
ax = sns.swarmplot(y="distribution",
                   x="name",
                   data=distribution_dataframe, 
                   color="white", 
                   edgecolor="gray",
                   s=0.8, # Circle size
                   palette=swarmplot_palette
             )

# Change axis labels, ticks and title
ax.set_xticks([0, 1, 2], ['MPH','Bifurcated','Zig-zag'])
ax.set_xlabel('Different geometries')
ax.set_ylabel(r'distance $W_{2}^{2}$')
# Fixed y-range; values outside 1.5 ... 3.5 are not visible.
plt.ylim(1.5, 3.5)

# Add horizontal grid
ax.grid(axis='y')
ax.set_axisbelow(True)

plt.show()
plt.close()

In [None]:
def simple_beeswarm(y, nbins=None):
    """
    Returns x coordinates for the points in ``y``, so that plotting ``x`` and
    ``y`` results in a bee swarm plot.

    The values are split into ``nbins`` equal-width bins along y. Within a
    bin the points are sorted by value and placed alternately right and left
    of x = 0; with an odd count the smallest point stays at x = 0. The step
    between neighbouring points is set by the fullest bin.

    Args:
        y: sequence of values.
        nbins: number of bins; defaults to ``len(y) // 6``. Values below 1
            (including the default for fewer than 6 points) are raised to 1.

    Returns:
        Array of x offsets with the same length as ``y``. Empty input gives
        an empty array.
    """
    y = np.asarray(y)
    x = np.zeros(len(y))
    if len(y) == 0:
        # Nothing to place; np.min / np.max would raise on empty input.
        return x

    if nbins is None:
        nbins = len(y) // 6
    # At least one bin is needed, otherwise the bin width is a division by 0.
    nbins = max(nbins, 1)

    # Get upper bounds of bins
    ylo = np.min(y)
    yhi = np.max(y)
    dy = (yhi - ylo) / nbins
    ybins = np.linspace(ylo + dy, yhi - dy, nbins - 1)

    # Divide indices into bins
    i = np.arange(len(y))
    ibs = [0] * nbins
    ybs = [0] * nbins
    nmax = 0
    for j, ybin in enumerate(ybins):
        f = y <= ybin
        ibs[j], ybs[j] = i[f], y[f]
        nmax = max(nmax, len(ibs[j]))
        # Continue with the points that did not fall into this bin.
        f = ~f
        i, y = i[f], y[f]
    ibs[-1], ybs[-1] = i, y
    nmax = max(nmax, len(ibs[-1]))

    # Assign x indices
    # When no bin holds two points, dx is never used; the guard only avoids
    # a division by zero.
    dx = 1 / max(nmax // 2, 1)
    for i, y in zip(ibs, ybs):
        if len(i) > 1:
            # j == 1 for an odd count: the smallest point is left at x = 0.
            j = len(i) % 2
            i = i[np.argsort(y)]
            a = i[j::2]
            b = i[j+1::2]
            x[a] = (0.5 + j / 3 + np.arange(len(b))) * dx
            x[b] = (0.5 + j / 3 + np.arange(len(b))) * -dx

    return x

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook
import numpy as np

fig, ax = plt.subplots(nrows=1, ncols=1)

# Fixing random state for reproducibility
np.random.seed(19680801)


# Non-zero pairwise distances of every geometry (dropping the zeros also
# removes the zeroed lower-triangle entries of the flattened matrices).
all_data = [distribution_mph[distribution_mph !=0], distribution_par[distribution_par != 0], distribution_cardioc[distribution_cardioc != 0], distribution_cardiod[distribution_cardiod != 0]]

# plot violin plot
ax.violinplot(all_data,
                  showmeans=False,
                  showmedians=True)
ax.set_title('Left ventricles under different collimation geometries')
# NOTE: the p-values in the tick labels are hard-coded text and are not
# recomputed from the data plotted here.
ax.set_xticks([1, 2, 3, 4], labels=['MPH\n p=3.61e-7', 'Parallel\n p=0.1088', 'CardioC\n p=0.0053', 'CardioD\n p=0.0075'])

plt.savefig("/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/images/" + "wilcoxon_violin_geometries.pdf", bbox_inches='tight', pad_inches=0)
plt.close()

In [None]:
# Bee swarm plot of the non-zero MPH distances.
fig = plt.figure(figsize=(2, 4))
fig.subplots_adjust(0.2, 0.1, 0.98, 0.99)
ax = fig.add_subplot(1, 1, 1)
y = distribution_mph[distribution_mph != 0]
x = simple_beeswarm(y)
ax.plot(x, y, 'o')

## ROC curves for the different datasets

In [None]:
from sklearn.metrics import confusion_matrix

def calculate_specificity_recall_precision(label_np, pred_np):
    """Return specificity, true-positive rate and false-positive rate.

    The prediction is binarised with the fixed threshold 0.2 and compared
    voxel by voxel with the label.

    Args:
        label_np: binary (0/1) ground-truth array.
        pred_np: prediction array of the same size.

    Returns:
        Tuple ``(specificity, tp_rate, fp_rate)``. Despite the function
        name, no precision is computed. A rate whose denominator is zero
        (no negatives or no positives in the label) is not a finite number.
    """
    label_flat = label_np.flatten()
    pred_flat = (pred_np > 0.2).astype(int).flatten()
    # labels=[0, 1] forces a 2x2 matrix even when only one class occurs;
    # otherwise the 4-way unpacking below fails on a 1x1 matrix.
    tn, fp, fn, tp = confusion_matrix(label_flat, pred_flat, labels=[0, 1]).ravel()
    specificity = tn / (tn + fp)
    tp_rate = tp / (tp + fn)
    fp_rate = fp / (fp + tn)
    return specificity, tp_rate, fp_rate

In [None]:
# Collect one (FPR, TPR) operating point per MPH prediction whose mask is
# not empty.
specificity_list, tp_rate_list, fp_rate_list = [], [], []

labels = targets_mph
predictions = predictions_mph
pred_myo = predictions_mph_myocard

for i in range(len(predictions)):
    if not pred_myo[i].max() > 0:
        continue
    print(predictions[i].shape)
    print(labels[i].shape)
    specificity, tp_rate, fp_rate = calculate_specificity_recall_precision(labels[i], predictions[i])
    # Note that 1 - specificity is what gets stored in specificity_list.
    specificity_list.append(1 - specificity)
    tp_rate_list.append(tp_rate)
    fp_rate_list.append(fp_rate)

In [None]:
# Two-column table (FPR, TPR), rows ordered by increasing FPR.
roc_data1 = np.stack((fp_rate_list, tp_rate_list), axis=1)
roc_data1 = roc_data1[np.argsort(roc_data1[:, 0])]
print(roc_data1)

In [None]:
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d, splrep, splev, PchipInterpolator
from sklearn.metrics import roc_curve

# NOTE(review): the x positions are evenly spaced over [0, 1]; the measured
# false-positive rates in roc_data1[:, 0] are only used for the ordering of
# the rows. Confirm that this is intended for an axis labelled
# "False Positive Rate".
x = np.linspace(0, 1, len(roc_data1[:, 0 ]))
y = roc_data1[:, 1]

# Monotone interpolation through the true-positive rates.
prec_int = PchipInterpolator(x, y)
x_ = np.linspace(0, 1, 500)

fig, axs = plt.subplots()
axs.plot(x_, prec_int(x_))

# plot dummy classifier
# (constant all-zero scores evaluated against the first label volume)
dummy_label = labels[0]
zeros = np.zeros(labels[0].shape)
ns_fpr, ns_tpr, _ = roc_curve(dummy_label.flatten(), (zeros).flatten())
plt.plot(ns_fpr, ns_tpr, linestyle='--', label='No Skill')

axs.spines['top'].set_visible(False)
axs.spines['right'].set_visible(False)
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")

plt.savefig("/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/allerton_2023/images/" + "roc_mph.pdf", bbox_inches='tight', pad_inches=0)
plt.close()

## Pickle some data because my machine can't handle massive computations and struggles after some time

In [None]:
import pickle

In [None]:
# save parallel predictions
# with open('/home/jackson/GIT/ELTE/rd-cv-heart-arm/data/pickle/parallel.pkl', 'wb') as file:
#     pickle.dump([targets_par, predictions_par, predictions_par_myocard], file)

# save mph predictions
# with open('/home/jackson/GIT/ELTE/rd-cv-heart-arm/data/pickle/mph.pkl', 'wb') as file:
#     pickle.dump([targets_mph, predictions_mph, predictions_mph_myocard], file)
    #pickle.dump(mph_params, file)

# save cardiod predictions
# with open('/home/jackson/GIT/ELTE/rd-cv-heart-arm/data/pickle/cardiod.pkl', 'wb') as file:
#     pickle.dump(predictions_cardiod, file)
#     pickle.dump(cardiod_params, file)

# save cardioc predictions
# with open('/home/jackson/GIT/ELTE/rd-cv-heart-arm/data/pickle/cardioc.pkl', 'wb') as file:
#     pickle.dump(predictions_cardioc, file)
#     pickle.dump(cardioc_params, file)

In [None]:
# Restore previously pickled results from local files.
# NOTE: pickle.load executes code embedded in the file - only load pickles
# that were produced by this notebook.
# NOTE(review): parallel_preds is initialised twice and never filled in this
# cell; the first initialisation precedes the MPH load.
parallel_preds = []
with open('/home/jackson/GIT/ELTE/rd-cv-heart-arm/data/pickle/mph.pkl', 'rb') as file:
    [targets_mph, predictions_mph, predictions_mph_myocard] = pickle.load(file) 

parallel_preds = []
with open('/home/jackson/GIT/ELTE/rd-cv-heart-arm/data/pickle/parallel.pkl', 'rb') as file:
    [targets_par, predictions_par, predictions_par_myocard] = pickle.load(file) 

# The CardioD / CardioC files are read as a single pickled object each.
cardiod_preds = []
with open('/home/jackson/GIT/ELTE/rd-cv-heart-arm/data/pickle/cardiod.pkl', 'rb') as file:
     cardiod_preds = pickle.load(file)

cardioc_preds = []
with open('/home/jackson/GIT/ELTE/rd-cv-heart-arm/data/pickle/cardioc.pkl', 'rb') as file:
     cardioc_preds = pickle.load(file)


## Checking the distribution of the left ventricles in the feature space

In [None]:
# Random 64^3 test volume, normalised in place (the return value is unused).
lv_volume = np.random.rand(64, 64, 64)
normalize_volume(lv_volume)

# Generate num_prior synthetic left-ventricle shapes with randomly drawn
# model parameters.
num_prior = 9
shape_priors = np.zeros([num_prior, *lv_volume.shape])

wall_thickness = np.random.uniform(0.3, 1.0, num_prior)
rot_angles = np.random.uniform(0, 2 * np.pi, num_prior)
curvature = np.random.uniform(1.5, 3, num_prior)
# The bounds are given in descending order here (-0.5 first, then -1).
sigmas = np.random.uniform(-0.5, -1, num_prior)

for i in range(num_prior):
    volume = np.zeros([*lv_volume.shape])
    params = dict(a=wall_thickness[i], c=curvature[i], sigma=sigmas[i])
    # NOTE(review): rot_mx is computed but not used below; transform_params
    # passes np.eye(3, 3), so the random rotation angles have no effect.
    rot_mx = R.from_quat([0, 0, np.sin(rot_angles[i]), np.cos(rot_angles[i])])

    transform_params = [np.eye(3, 3), [16, 16, 0], 1.5]
    shape_priors[i] = lv_indicator(volume, params, transform_params, a_plot=False)

    recon_mode = 'basic'
    fprojector = forward_projector(recon_mode)

    # NOTE(review): `frames` is overwritten on every iteration and not used
    # inside this cell.
    frames = fprojector(shape_priors[i])
    
# lv_volume = shape_priors[3]

In [None]:
from geomloss import SamplesLoss

# Gaussian-type kernel on point sets: exp(-loss(x, y)^2 / (2 * sigma^2)),
# with a Sinkhorn loss (p=2, blur=eps) as the distance.
# NOTE: another cell rebinds `k` and `sigma` with a different kernel formula;
# the one that was executed last is the one in effect.
eps = 5 * 1e-3
loss = SamplesLoss(loss='sinkhorn', p=2, blur=eps)
sigma = 5 * 1e0
k = lambda x, y, sigma: torch.exp(-loss(x, y) ** 2 / (2 * sigma ** 2))
centering_point = np.array([0.45, 0.45, 0.45])

# Build the nonlinear shape prior from the synthetic shapes; the returned
# quantities are stored as notebook globals (the sixth value is discarded).
z_i, sigma_inv, L, V, sigma_ort, _, first_cplx, min_shape_face_count, mean_shape, mean_shape_face, k_matrix_sum, k_matrix = nonlinear_shape_prior(
    shape_priors, kernel=k, sigma=sigma, centering_point=centering_point)

In [None]:
# input_shape = torch.rand([400, 3])
# Extract the 0.5 iso-surface of the current volume and decimate it to the
# face count used by the shape prior.
verts, faces = mcubes.marching_cubes(lv_volume, 0.5)
v_decimate, f_decimate, v_correspondence, f_correspondence = pcu.decimate_triangle_mesh(verts, faces.astype(np.int32),
                                                                                        min_shape_face_count)
# NOTE(review): `height` is not assigned in this cell; it must already exist
# as a notebook global - confirm that it matches lv_volume.
input_shape = torch.from_numpy(v_decimate / height)
# NOTE: this global shadows the builtin `input`.
input = input_shape.detach().numpy()

# Sanity checks: loss between the input and the first prior, and between two
# priors.
print(loss(input_shape.double(), z_i[0]))
print(loss(z_i[1], z_i[-1]))
print(input.shape)

In [None]:
import time

z = torch.from_numpy(input)  # torch.from_numpy(mean_shape / 64).double()
z.requires_grad = True

# Start the clock immediately before the call being timed. `start` was never
# set in this cell, so the elapsed time was measured from whatever an earlier
# cell had left in it (or raised a NameError).
start = time.time()
grad_E_opt = E_phi_grad_opt(V, k, k_matrix, k_matrix_sum, 5 * 1e0, z_i, z, L,
                            sigma_ort, first_cplx, len(z_i))
print("Optimized gradient calculation time: ", round(time.time() - start, 2), " seconds")
# Compare with the reference gradient grad_E, which is not computed in this
# cell and must already exist as a notebook global.
print("Lazy gradient max: ", grad_E.max(), "Lazy gradient min: ", grad_E.min())
print("Opt gradient max: ", grad_E_opt.max(), "Opt gradient min: ", grad_E_opt.min())
print("Lazy grad norm: ", torch.linalg.norm(grad_E))
print("Opt grad norm: ", torch.linalg.norm(grad_E_opt))
print("Norm difference: ", torch.linalg.norm(grad_E - grad_E_opt))