In [None]:
from pathlib import Path
from typing import Tuple

import yaml
from tqdm.notebook import tqdm
import nrrd
import numpy as np
import pandas as pd
import torchio as tio
import torch
from unet import UNet
import point_cloud_utils as pcu
import copy as cp

from skimage.metrics import structural_similarity as ssim
from skimage.util import montage
from sklearn.metrics import confusion_matrix

from scipy.spatial.transform import Rotation as R

from src.tools.recon.projector import forward_projector, backward_projector
from src.tools.manip.manip import normalize_volume

# data fetching and handling
from data.check_database import load_remote_data
from data.fetch_data import fetch_data
from src.tools.data.loadvolumes import LoadVolumes

import matplotlib.pyplot as plt

# CMF algorithm shape prior enhancement 
from src.algs.arm import lv_indicator
from src.tools.cmf.cmf_shape_prior import cmf_shape_prior
from src.tools.kde.nonlinear_shape_prior import nonlinear_shape_prior, nonlinear_shape_prior_grad

In [None]:
# copied from supervised_learning.py

# Experiment configuration file; the path is resolved against the current working
# directory, so it only holds when the notebook is started from its own folder.
path_experiment_conf = Path().resolve().joinpath('../../../exp/supervised_training_best.yaml')

with open(path_experiment_conf, 'r') as file:
    # NOTE(review): yaml.Loader can construct arbitrary Python objects. Only use it on
    # trusted configuration files; consider yaml.safe_load if the file holds plain data.
    conf = yaml.load(file, Loader=yaml.Loader)

In [None]:
# Display the loaded experiment configuration as the cell output.
conf

# Data fetching from remote server 

In [None]:
# read all filenames from the url
from bs4 import BeautifulSoup
import requests

# Remote layout: <url>/recon/spie_2024/misc/{data,label}/<file name>
REMOTE_DATA_DIR = '/recon/' + 'spie_2024/' + 'misc/' + 'data/'
REMOTE_LABEL_DIR = '/recon/' + 'spie_2024/' + 'misc/' + 'label/'


def list_remote_hrefs(listing_url):
    """Return the ``href`` of every ``<a>`` tag found on the page at ``listing_url``.

    :param listing_url: URL of the remote directory listing
    :return: list with one ``href`` value per anchor, in page order
    :raises requests.HTTPError: when the server answers with an error status, so a
        missing directory is reported here instead of being parsed as a listing
    """
    page = requests.get(listing_url)
    page.raise_for_status()
    soup = BeautifulSoup(page.content, 'html.parser')
    return [anchor.get('href') for anchor in soup.find_all('a')]


dicom_loader = LoadVolumes()

# initialize data fetching from remote, configuration is in data/remote.yml
data_loaded = False
url, datasets = load_remote_data()

label_names = list_remote_hrefs(url + REMOTE_LABEL_DIR)
data_names = list_remote_hrefs(url + REMOTE_DATA_DIR)

subjects = []
subjects_data = []
load_flags = []  # one "loaded" flag per patient, checked after the loop

# fetch specific patient data
for index, dicom_name in enumerate(data_names):

    # NOTE(review): volumes and labels are paired purely by their position in the two
    # listings -- confirm both directories enumerate the patients in the same order.
    label_name = label_names[index]
    data_url = url + REMOTE_DATA_DIR + dicom_name
    label_url = url + REMOTE_LABEL_DIR + label_name

    # fetch the data from remote
    data = fetch_data(data_url)
    lab = fetch_data(label_url)

    # load data with the dicom loader
    volume, data_loaded = dicom_loader.LoadSinglePatient(data)
    load_flags.append(data_loaded)
    header = nrrd.read_header(lab)
    # reverse the axis order of the label volume once and reuse it below
    labels_t = np.transpose(nrrd.read_data(header, lab), [2, 1, 0])

    # looks like the label export is a bit tricky so loading shall be updated:
    # the label value (1 or 2) with FEWER voxels is kept as the foreground mask
    prob_val_1 = np.sum(labels_t == 1)
    prob_val_2 = np.sum(labels_t == 2)

    if prob_val_1 > prob_val_2:
        labels = np.where(labels_t == 2, 1, 0)
    else:
        labels = np.where(labels_t == 1, 1, 0)

    subject = tio.Subject(
        spect=tio.ScalarImage(tensor=volume[None, ...]),
        left_ventricle=tio.LabelMap(tensor=labels[None, ...])
    )
    subjects.append(subject)

    age, gender, weight, height = dicom_loader.CalculatePatientStatistics()
    subject_data = {
        'age' : age,
        'gender' : gender,
        'weight' : weight,
        'height' : height
    }
    subjects_data.append(subject_data)

    print("Volume shape: ", volume.shape, "Labels shape:", labels.shape)

    # normalizing the frame values
    # NOTE(review): the return value is discarded and the subject was already built
    # from `volume` above -- confirm normalize_volume works in place and that the
    # subject tensor is meant to see (or not see) the normalisation.
    normalize_volume(volume)

# every patient must have been loaded, not only the last one
assert load_flags and all(load_flags), "at least one patient volume failed to load"

In [None]:
# Debug output: remote base URL and dataset configuration, followed by the data URL
# and the fetched label object left over from the last iteration of the fetching loop.
print(url)
print(datasets)
print(data_url)
print(lab)

In [None]:
import matplotlib.pyplot as plt

%matplotlib notebook
slice = 40

fig = plt.figure()
ax1 = fig.add_subplot(1, 3, 1)
ax1.imshow(volume[slice, :, :])
ax2 = fig.add_subplot(1, 3, 2)
ax2.imshow(labels[slice, :, :])
plt.show()

In [None]:
# Inspect the 'spect' image of the subject at index 20 (hard-coded index, so at
# least 21 subjects must have been fetched).
print(subjects[20]['spect'])

# Load model and check devices

In [None]:
# copied from supervised_learning.py

def get_model_and_optimizer(
    config: dict,
    device: str
) -> Tuple[torch.nn.Module, torch.optim.Optimizer]:
    """
    Build the 3D two-class U-Net and the AdamW optimizer that trains it.

    :param config: experiment configuration; ``config['model']['UNet']`` is forwarded
        to ``UNet`` as extra keyword arguments and
        ``config['optimizer']['learning_rate']`` is used as the AdamW learning rate
    :param device: device the freshly built model is moved to
    :return: tuple ``(model, optimizer)``
    """
    # architecture settings that are fixed for this notebook
    fixed_unet_args = dict(
        in_channels=1,
        out_classes=2,
        dimensions=3,
        upsampling_type='linear',
        padding=True,
        activation='PReLU',
    )
    network = UNet(**fixed_unet_args, **config['model']['UNet'])
    network = network.to(device)

    adamw = torch.optim.AdamW(network.parameters(), lr=config['optimizer']['learning_rate'])

    return network, adamw

In [None]:
path_saved_models = Path().resolve().parent.parent.joinpath('saved_models')

# supervised fine-tuning experiment name
experiment = conf['experiment_name']
path_weights = path_saved_models.joinpath(f"{experiment}.pth")

# Choose the device before loading so the checkpoint tensors can be mapped onto it;
# without map_location a checkpoint saved on CUDA cannot be loaded on a CPU-only host.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
weights = torch.load(path_weights, map_location=device)['weights']

model, optimizer = get_model_and_optimizer(conf, device)

In [None]:
# Copy the stored checkpoint weights into the freshly built model.
model.load_state_dict(weights)

# Segmentation with trained model 

In [None]:
def load_image(path):
    """Read an NRRD file and return its voxels as float32 plus an identity affine.

    The NRRD header is read but not used.

    :param path: location of the NRRD file
    :return: tuple ``(data, affine)`` where ``affine`` is ``np.eye(4)``
    """
    voxels, _header = nrrd.read(path)
    return voxels.astype(np.float32), np.eye(4)


def prepare_batch(batch, device):
    """Pull the input and target tensors out of a batch and move them to ``device``.

    :param batch: mapping with the keys ``'spect'`` and ``'left_ventricle'``, each
        holding its tensor under ``tio.DATA``
    :param device: device both tensors are moved to
    :return: tuple ``(inputs, targets)``
    """
    def _on_device(key):
        return batch[key][tio.DATA].to(device)

    return _on_device('spect'), _on_device('left_ventricle')


class Visualizer:
    """Small plotting helper that shows volumes as a 2D montage of their slices."""

    def montage_nrrd(self, image):
        """Return a montage of ``image`` or, for 2D input, ``image`` itself.

        :param image: array-like exposing ``.shape``
        :return: ``montage(image)`` when the input has more than two dimensions;
            otherwise the unchanged input, after emitting a ``RuntimeWarning``
        """
        # local import: `warnings` is not among the notebook-level imports, so the
        # fallback branch used to fail with a NameError instead of warning
        import warnings

        if len(image.shape) > 2:
            return montage(image)

        warnings.warn('Pass a 3D volume', RuntimeWarning)
        return image

    def visualize(self, image, mask=None):
        """Plot the montage of ``image`` and, when given, the montage of ``mask`` next to it.

        :param image: volume to show
        :param mask: optional second volume shown in a second panel
        """
        if mask is None:
            fig, axes = plt.subplots(1, 1, figsize=(10, 10))
            axes.imshow(self.montage_nrrd(image))
            axes.set_axis_off()
        else:
            fig, axes = plt.subplots(1, 2, figsize=(40, 40))

            for i, data in enumerate([image, mask]):
                axes[i].imshow(self.montage_nrrd(data))
                axes[i].set_axis_off()
 

def compute_metrics(prediction, target, epsilon=1e-9):
    """Compute batch-averaged precision, recall, IoU and Dice of the foreground class.

    Both inputs are reduced with ``argmax`` over dimension 1 and the resulting index
    maps are treated as binary masks, so the function assumes two classes with
    index 1 as foreground. Counts are summed over dimensions 1-3 of the index maps,
    i.e. the inputs are expected to be 5D ``(batch, class, x, y, z)``.

    :param prediction: per-class scores or probabilities
    :param target: one-hot encoded ground truth of the same shape
    :param epsilon: small constant added to every denominator so that an empty
        prediction or target yields 0 instead of NaN
    :return: tuple ``(precision, recall, iou, dice_score)`` of Python floats
    """
    pred = prediction.argmax(dim=1)
    targ = target.argmax(dim=1)

    spatial_dims = (1, 2, 3)
    tp = (targ * pred).sum(dim=spatial_dims)        # foreground in both
    fp = (pred * (1 - targ)).sum(dim=spatial_dims)  # predicted foreground, labelled background
    fn = ((1 - pred) * targ).sum(dim=spatial_dims)  # predicted background, labelled foreground

    def _batch_mean(ratio):
        # average the per-sample ratio over the batch and return a plain float
        return ratio.mean().item()

    precision = _batch_mean(tp / (tp + fp + epsilon))
    recall = _batch_mean(tp / (tp + fn + epsilon))
    iou = _batch_mean(tp / (tp + fp + fn + epsilon))
    dice_score = _batch_mean((2 * tp) / (2 * tp + fp + fn + epsilon))

    return precision, recall, iou, dice_score

In [None]:
# subjects = []
# NOTE: within this cell target_shape is only referenced by the commented-out
# CropOrPad step below.
target_shape = (128, 128, 128)

# Preprocessing applied to every labelled subject before inference.
transform_pipeline = tio.Compose([
    # resampling reference is the 'spect' image of the subject at hard-coded index 20
    tio.Resample(subjects[20]['spect']),
    tio.ToCanonical(),
    #tio.CropOrPad(target_shape=target_shape, mask_name="left_ventricle"),
    tio.ZNormalization(),
    tio.OneHot()
])

In [None]:
# Dataset over the fetched subjects with the preprocessing pipeline attached.
dataset = tio.SubjectsDataset(subjects, transform=transform_pipeline)
print(f"Dataset size: {len(dataset)} subjects")

# One subject per batch, original order, loading in the main process.
loader_options = {'batch_size': 1, 'shuffle': False, 'num_workers': 0}
data_loader = torch.utils.data.DataLoader(dataset, **loader_options)

In [None]:
def calculate_specificity_recall_precision(label_np, pred_np):
    """Compute specificity, recall and precision of a thresholded prediction.

    :param label_np: array of ground-truth labels with the values 0 and 1
    :param pred_np: array of foreground scores; values above 0.5 count as foreground
    :return: tuple ``(specificity, recall, precision)``; a ratio whose denominator
        is zero (e.g. recall for a label without foreground) evaluates to NaN
    """
    label_flat = label_np.flatten()
    pred_flat = (pred_np > 0.5).astype(int).flatten()
    # Pin both classes: without `labels` the matrix shrinks to 1x1 whenever label and
    # prediction contain a single class, and the four-way unpacking below fails.
    tn, fp, fn, tp = confusion_matrix(label_flat, pred_flat, labels=[0, 1]).ravel()
    specificity = tn / (tn + fp)
    recall = tp / (tp + fn)
    precision = tp / (tp + fp)
    return specificity, recall, precision

In [None]:
# Channel index of the foreground class in the softmax output.
FOREGROUND = 1
vis = Visualizer()

model.eval()

# Per-subject statistics collected over the whole data loader.
num_samples = 0
# NOTE: despite its name this list stores 1 - specificity, see the append below.
specificity_list = []
recall_list = []
precision_list = []


for batch_idx, batch in enumerate(tqdm(data_loader)):

    inputs, targets = prepare_batch(batch, device)
    
    with torch.no_grad():

        predictions = model(inputs).softmax(dim=1)
        # the slice FOREGROUND: keeps the channel axis of the foreground probability
        probabilities = predictions[:, FOREGROUND:].cpu()
    

    for i in range(len(batch['spect'][tio.DATA])):
    
        # NOTE: `spect` is only needed by the commented-out visualisation below
        spect = batch['spect'][tio.DATA][i].permute(3, 0, 1, 2)
        # channel 1 onwards of the one-hot target, i.e. the foreground mask
        label = batch['left_ventricle'][tio.DATA][i][1:, ...].permute(3, 0, 1, 2)
        pred = probabilities[i].permute(3, 0, 1, 2)
        
        # vis.visualize(
        #     np.squeeze(label.permute(1,0,2,3).numpy(), axis=0),
        #     np.squeeze(pred.permute(1,0,2,3).numpy(), axis=0)
        # )
        
        # label and prediction go through the same squeeze/permute, so they stay
        # voxel-aligned when flattened inside the metric function
        label_np = label.squeeze().permute(1, 2, 0).numpy()
        pred_np = pred.squeeze().permute(1, 2, 0).numpy()
        
        num_samples += 1
        
        specificity, recall, precision = calculate_specificity_recall_precision(label_np, pred_np)
        specificity_list.append(1 - specificity)
        recall_list.append(recall)
        precision_list.append(precision)


In [None]:
# Report the segmentation metrics averaged over every batch of the data loader.
# Evaluating the `predictions` / `targets` left over from a previous loop would only
# describe the last batch, so the loader is traversed again here.
batch_metrics = []
for batch in data_loader:
    inputs, targets = prepare_batch(batch, device)
    with torch.no_grad():
        predictions = model(inputs).softmax(dim=1)
    batch_metrics.append(compute_metrics(predictions, targets))

# each entry is (precision, recall, iou, dice); average column-wise over the batches
precision, recall, iou, dice = np.mean(batch_metrics, axis=0)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"IoU: {iou:.4f}")
print(f"Dice score: {dice:.4f}")

In [None]:
# Shape of the softmax output kept from the most recently processed batch.
predictions.shape

In [None]:
# Hard class maps and input volume of the most recently processed batch.
pred = predictions.argmax(dim=1) # predicted class index per voxel
target = targets.argmax(dim=1) # label class index per voxel
lv_spect = batch['spect'][tio.DATA][0]
%matplotlib notebook
# show the plane at index 16 of the last axis of the first channel
plt.imshow(lv_spect[0, :, :, 16].cpu().numpy())

# Running shape prior enhanced CMF on predictions

In [None]:
# Synthetic shape priors: `num_prior` volumes produced by lv_indicator with randomly
# drawn parameters. NOTE(review): no random seed is set, so the priors differ on every run.
lv_volume = np.zeros([64, 64, 64])

num_prior = 9
shape_priors = np.zeros([num_prior, *lv_volume.shape])

wall_thickness = np.random.uniform(0.3, 1.0, num_prior)
rot_angles = np.random.uniform(0, 2 * np.pi, num_prior)
curvature = np.random.uniform(1.5, 3, num_prior)
# NOTE(review): the bounds are given as (low=-0.5, high=-1), i.e. in reversed order -- confirm intent
sigmas = np.random.uniform(-0.5, -1, num_prior)

for i in range(num_prior):
    # NOTE: rebinds the notebook-level name `volume` used by earlier cells
    volume = np.zeros([*lv_volume.shape])
    params = dict(a=wall_thickness[i], c=curvature[i], sigma=sigmas[i])
    # NOTE(review): rot_mx is computed but never used; transform_params below passes
    # the identity matrix instead -- confirm whether the rotation should be applied
    rot_mx = R.from_quat([0, 0, np.sin(rot_angles[i]), np.cos(rot_angles[i])])

    transform_params = [np.eye(3, 3), [16, 16, 0], 1.5]
    shape_priors[i] = lv_indicator(volume, params, transform_params, a_plot=False)

In [None]:
# u_init = pred[0]
# lv_volume = lv_spect[0]
# 
# sigma_inv, mean_shape, dec_faces = nonlinear_shape_prior(shape_priors, 1.0, 52)
# 
# opt_params = dict(num_iter=4, err_bound=0, gamma=1e-1, steps=1e-1)
# cmf_params = dict(par_lambda=10, u_init=u_init.cpu().to(torch.float32) ,par_nu=1, c_zero=0.1, c_one=0.8, b_zero=0.1, b_one=0.7, sigma_inv=sigma_inv, mean_shape=mean_shape, faces=dec_faces)
# lam, err_iter, num_iter = cmf_shape_prior(a_volume=lv_volume.cpu(), a_opt_params=opt_params, a_algo_params=cmf_params)

# Running CMF shape prior on hypoperfused hearts

In [None]:
# NOTE(review): absolute, machine-specific paths -- these only exist on the original
# author's workstation; consider a configurable data directory.
path_images = [
    "/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/data/tc99_female_no_defect.nrrd",
    "/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/data/tc99_inferior_perf_defect.nrrd",
    "/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/data/tc99m_reversible_defect.nrrd",
    "/home/jackson/GIT/ELTE/papers/left_ventricle_segmentation/data/tc99m_stable_perfusion_defect.nrrd"
]

images = []

# load images from local folder (raw volumes, not passed through `transform` below)
for path in path_images:
    img, _ = load_image(path)
    images.append(torch.from_numpy(img).to(device))

# running the network on ill conditioned patients
pred = []

target_shape = (64, 64, 64)
transform = tio.Compose([
    tio.CropOrPad(target_shape=target_shape),
    tio.ZNormalization(),
    tio.OneHot()
])

# NOTE: rebinds `subjects`; from here on the name no longer refers to the labelled
# subjects fetched from the remote server.
subjects = []

for path in path_images:
    with torch.no_grad():
        # NOTE(review): `reader=load_image` is passed to tio.Subject, not to
        # tio.ScalarImage -- confirm the custom reader is actually meant to be used.
        subject = tio.Subject(
            lv_volume = tio.ScalarImage(path),
            reader=load_image
        )
        subject.load()
        transformed = transform(subject)
        subjects.append(transformed)
        
        # foreground probability for a batch of one; selecting [:, FOREGROUND] drops the class axis
        img_pred = model(transformed['lv_volume'][tio.DATA][None, ...].to(device)).softmax(dim=1)[:, FOREGROUND].cpu() # for debugging purpose
        pred.append(img_pred)

In [None]:
# run shape prior based CMF enhancement on predictions
z_i, sigma_inv, L, V, sigma_ort, sigma, first_cplx, min_shape_face_count, mean_shape, mean_shape_face  = nonlinear_shape_prior(shape_priors, 1.0, 52)

lam = []
for i in range(len(images)):
    u_init = pred[i][0]
    lv_volume = subjects[i]['lv_volume'][tio.DATA][0]
    
    opt_params = dict(num_iter=10, err_bound=0, gamma=1e-1, steps=1e-1)
    cmf_params = dict(u_init=u_init.cpu().to(torch.float32), par_lambda=1, par_nu=5, c_zero=0.0, c_one=0.7, b_zero=1e-1, b_one=1e1,
                      z_i=z_i, sigma_inv=sigma_inv, L=L, V=V, sigma_ort=sigma_ort, sigma=5 * 1e-3, first_cplx=first_cplx, min_shape_face_count=min_shape_face_count, mean_shape=mean_shape, mean_shape_face=mean_shape_face)
    
    lam_, err_iter, num_iter = cmf_shape_prior(a_volume=lv_volume.cpu(), a_opt_params=opt_params, a_algo_params=cmf_params)
    lam.append(lam_)

## Plotting the results

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook

# Which hypoperfused image to show and which slice indices to use; only
# sa_slice_ind is used by the plots in this cell.
image_ind = 1
tra_slice_ind = 32 # 32 stblprfdfct 30:50, 15:35 | 32 infrdfct 30:50, 10:30
vla_slice_ind = 20 # 25 stblprfdfct 30:50, 15:35 | 20 infrdfct 30:50, 20:40
sa_slice_ind = 40 # 40 stblprfdfct 15:35, 20:40 | 40 infrdfct 10:30, 20:40
# left to right: raw image, network prediction, CMF result
# NOTE(review): `images` holds the raw volumes while `pred` comes from the
# cropped/padded input -- confirm the same slice index addresses the same plane in both.
fig, axs = plt.subplots(1, 3)
axs[0].imshow(images[image_ind][sa_slice_ind, :, :].cpu())
axs[1].imshow(pred[image_ind][0, sa_slice_ind, :, :].cpu())
axs[2].imshow(lam[image_ind][sa_slice_ind, :, :].cpu())
plt.show()
plt.close()

In [None]:
# patient plotting and saving
imgs = []
imgs.append(images[image_ind][sa_slice_ind, :, :].cpu())
imgs.append(pred[image_ind][0, sa_slice_ind, :, :].cpu())
imgs.append(lam[image_ind][sa_slice_ind, :, :].cpu())

labs = ['img', 'pred', 'pred_cmf']
for i in range(len(imgs)):
    fig = plt.imshow(imgs[i][10:30, 20:40])
    ax = fig.axes
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    
    # plt.savefig('pat_'+ labs[i] +'_tc99_inferior_perf_defect_' + 'sa' + '.png', bbox_inches='tight', pad_inches=0)
    #plt.close()

# Running CMF shape prior on labeled dataset and evaluation metrics

In [ ]:
from geomloss import SamplesLoss
# blur value handed to the Sinkhorn loss
eps = 5 * 1e-3
loss_unbalanced = SamplesLoss(loss='sinkhorn', p=2, blur=eps, scaling=0.95)
# kernel width passed to nonlinear_shape_prior; NOTE: the name `sigma` is rebound by
# the tuple unpacking at the end of this cell
sigma = 5 * 1e0

# kernel between two shapes: exp(-sigma * sinkhorn_loss(x, y)); the lambda's `sigma`
# is its own parameter, not the notebook-level variable
k = lambda x, y, sigma : torch.exp(-sigma * loss_unbalanced(x, y))
centering_point = np.array([0.45, 0.45, 0.45])

z_i, sigma_inv, L, V, sigma_ort, sigma, first_cplx, min_shape_face_count, mean_shape, mean_shape_face, k_matrix_sum, k_matrix  = nonlinear_shape_prior(shape_priors, kernel=k, sigma=sigma, centering_point=centering_point)

In [None]:
# CMF-refined masks, one entry per sample of the whole data loader. The list is
# created once here; resetting it inside the batch loop would keep only the last batch.
cmf_pred = []

FOREGROUND = 1
for batch_idx, batch in enumerate(tqdm(data_loader)):
    inputs, targets = prepare_batch(batch, device)
    with torch.no_grad():
        probabilities = model(inputs).softmax(dim=1)[:, FOREGROUND:].cpu()

    for i in range(inputs.shape[0]):
        # NOTE(review): the FOREGROUND: slice keeps a leading channel axis on u_init,
        # while the hypoperfused section passes a 3D u_init -- confirm what
        # cmf_shape_prior expects.
        u_init = probabilities[i]
        # First channel of the network input of this sample. `subjects` is a plain
        # Python list (and holds the hypoperfused cases by now), so it cannot be
        # indexed with `[i, 0]`; the batch tensor is the matching volume.
        lv_volume = inputs[i, 0].cpu()

        opt_params = dict(num_iter=14, err_bound=0, gamma=1e-2, steps=1e-1)
        cmf_params = dict(u_init=u_init.cpu().to(torch.float32), par_lambda=1.0, par_nu=0.7, c_zero=0.4, c_one=0.5, b_zero=1e-1, b_one=1e1,
                          z_i=z_i, sigma_inv=sigma_inv, L=L, V=V, sigma_ort=sigma_ort, sigma=sigma, first_cplx=first_cplx, min_shape_face_count=min_shape_face_count, mean_shape=mean_shape, mean_shape_face=mean_shape_face, k_matrix_sum=k_matrix_sum, k_matrix=k_matrix, kernel=k)
        lam, err_iter, num_iter, lam_shape_prior = cmf_shape_prior(a_volume=lv_volume, a_opt_params=opt_params, a_algo_params=cmf_params)

        # fill: flood the shape-prior result from the corner voxel [0, 0, 0] with
        # `fill_value`; everything that was not reached (value <= 1) becomes the mask
        fill_value = 2
        label_prior = cp.copy(lam_shape_prior)
        filled_myocard = pcu.flood_fill_3d(label_prior, [0, 0, 0], fill_value)
        filled_myocard = np.where( filled_myocard <= 1, 1, 0)
        # CMF output restricted to the filled mask (kept for inspection, not stored)
        pred_myocard = np.where( filled_myocard == 1, lam, 0)

        cmf_pred.append(filled_myocard)

## Compute metric results

In [None]:
# Metrics of the CMF-refined mask against the ground truth.
# NOTE: `targets` only holds the final batch of the CMF loop, so a single sample (the
# first one of that batch) is evaluated here.
epsilon = 1e-9

# class index map of the ground truth, shape = spatial dimensions
targ = targets[0].argmax(dim=0)

# The CMF masks are binary NumPy arrays (np.where output), so they have to be
# converted to a tensor before being moved to the device, and must NOT be reduced with
# an argmax over the sample axis. Index -batch_size selects the first sample of the
# last batch, whether the list holds every batch or only the last one.
cmf_mask = cmf_pred[-targets.shape[0]]
pred = torch.as_tensor(np.asarray(cmf_mask)).to(device).reshape(targ.shape).long()

tp = (targ * pred).sum()        # foreground in both
fp = (pred * (1 - targ)).sum()  # predicted foreground, labelled background
fn = ((1 - pred) * targ).sum()  # predicted background, labelled foreground

# epsilon keeps every ratio finite when a denominator is zero
precision = (tp / (tp + fp + epsilon)).item()
recall = (tp / (tp + fn + epsilon)).item()
iou = (tp / (tp + fp + fn + epsilon)).item()
dice_score = ((2 * tp) / (2 * tp + fp + fn + epsilon)).item()

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"IoU: {iou:.4f}")
# print the value computed in this cell, not a stale `dice` from an earlier cell
print(f"Dice score: {dice_score:.4f}")