In [16]:
import matplotlib.pyplot as plt
import numpy as np

from pydicom import dcmread
import cv2
from mpl_toolkits.mplot3d import Axes3D
from pathlib import Path
import natsort
import json
import h5py
import pandas as pd
import nibabel as nib
from dtd import autopatch
import pprint as pp
import plotly.graph_objects as go
import os
%matplotlib widget


In [4]:
def get_pixel_coordinates(dcm_file):
    """Map every pixel of a DICOM slice to patient coordinates.

    Uses the DICOM image-plane mapping
        P(row, col) = IPP + X * PixelSpacing[1] * col + Y * PixelSpacing[0] * row
    where X / Y are the row / column direction cosines stored in
    ImageOrientationPatient. PixelSpacing is (row spacing, column spacing), so a
    step of one column uses PixelSpacing[1] and a step of one row uses
    PixelSpacing[0] (the same convention as ``get_dicom_affine``).

    Parameters
    ----------
    dcm_file : pydicom dataset (or any object exposing ``pixel_array``,
        ``ImagePositionPatient``, ``ImageOrientationPatient`` and ``PixelSpacing``).

    Returns
    -------
    np.ndarray, shape (rows, cols, 4)
        ``[..., :3]`` are the x, y, z patient coordinates of each pixel and
        ``[..., 3]`` is the pixel intensity.
    """
    img = np.asarray(dcm_file.pixel_array)
    n_rows, n_cols = img.shape[0], img.shape[1]

    ipp = np.asarray(dcm_file.ImagePositionPatient, dtype=float)  # centre of the first (top-left) pixel
    iop = np.asarray(dcm_file.ImageOrientationPatient, dtype=float)
    ps = np.asarray(dcm_file.PixelSpacing, dtype=float)  # (row spacing, column spacing)

    col_step = iop[:3] * ps[1]   # displacement for +1 column (along the row cosine X)
    row_step = iop[3:6] * ps[0]  # displacement for +1 row (along the column cosine Y)

    # Vectorised replacement of the former per-pixel matrix product.
    rows, cols = np.indices((n_rows, n_cols))
    P_out = np.empty((n_rows, n_cols, 4))
    P_out[..., :3] = ipp[:3] + cols[..., None] * col_step + rows[..., None] * row_step
    P_out[..., 3] = img
    return P_out

def get_3d_coordinates(dcm_file, i, j):
    """Patient coordinates of a single pixel of a DICOM slice.

    Parameters
    ----------
    dcm_file : pydicom dataset with ``ImagePositionPatient``,
        ``ImageOrientationPatient`` and ``PixelSpacing``.
    i : column index (moves along the row direction cosine, first 3 IOP values).
    j : row index (moves along the column direction cosine, last 3 IOP values).

    Returns
    -------
    np.ndarray, shape (3,)
        x, y, z of the pixel. PixelSpacing is (row spacing, column spacing), so
        the column step uses PixelSpacing[1] and the row step PixelSpacing[0],
        consistent with ``get_dicom_affine``.
    """
    ipp = np.asarray(dcm_file.ImagePositionPatient, dtype=float)
    iop = np.asarray(dcm_file.ImageOrientationPatient, dtype=float)
    ps = np.asarray(dcm_file.PixelSpacing, dtype=float)

    return ipp[:3] + iop[:3] * ps[1] * i + iop[3:6] * ps[0] * j

def load_lax(dcm_folder, target_frame_num):
    """Return a one-element list holding the LAX file name for one cine frame.

    The entries of ``dcm_folder`` (a ``pathlib.Path``) are naturally sorted, so
    "10" follows "9". The name at 0-based position ``target_frame_num`` is
    selected.

    NOTE(review): every directory entry is counted, including hidden files such
    as ``.DS_Store``. Confirm the folder contains only frame files.
    """
    ordered_names = natsort.natsorted([entry.name for entry in dcm_folder.iterdir()])
    selected_name = ordered_names[target_frame_num]
    return [selected_name]
    


In [None]:

subject_id = 27
# Path(Path.home(), "/abs/...") ignores the home directory because the second
# component is absolute, so the folders are built directly from the absolute paths.
dcm_folder_sax = Path(f"/Users/giuliamonopoli/Desktop/PhD /Data/ES_files/{subject_id}/DICOM_files")
dcm_folder_lax = Path(f"/Users/giuliamonopoli/Desktop/PhD /Data/MAD_OUS_sorted/{subject_id}/cine/4ch/")
target_frame_num = 13  # ES frame number

# Keep only SAX files whose name carries a slice location ("..._sliceloc_<N>.<ext>").
# Also skip hidden/system files (e.g. .DS_Store, AppleDouble "._*").
# The previous filter (f != ".dcm") let such files through, which broke the sort
# key below or dcmread.
img_paths = [
    f for f in os.listdir(dcm_folder_sax)
    if "sliceloc_" in f and not f.startswith(".")
]
if not img_paths:
    raise FileNotFoundError(f"No '*sliceloc_*' SAX files found in {dcm_folder_sax}")

# Order SAX slices by the integer slice location encoded in the file name.
file_list_sax = sorted(img_paths, key=lambda x: int(os.path.basename(x).split('sliceloc_')[1].split('.')[0]))
file_list_lax = load_lax(dcm_folder_lax, target_frame_num)

dcm_sax = [dcmread(dcm_folder_sax / f) for f in file_list_sax]
dcm_lax = [dcmread(dcm_folder_lax / f) for f in file_list_lax]

In [6]:
" Get MV insertion points from DeepValve annotations"

from dataclasses import dataclass, field, InitVar
from typing import List, Dict
@dataclass
class Annotation:
    """Landmark annotations for one frame of a DeepValve record.

    Field names match the keys of each per-frame annotation dict in the JSON
    export (e.g. ``"mv_insert_septal"``).

    NOTE(review): in the visible code this class is never instantiated.
    ``PatientData(**record)`` keeps the per-frame annotations as raw JSON dicts.
    """
    # Each field presumably holds a list of annotated points (each a list of
    # coordinates) -- TODO confirm against the DeepValve export format.
    mv_insert_septal: List[List[int]]  # mitral-valve insertion point, septal side
    mv_insert_lateral: List[List[int]]  # mitral-valve insertion point, lateral side
    lv_base_septal: List[List[int]]
    lv_base_lateral: List[List[int]]
    leaflet_septal: List[List[int]]
    leaflet_lateral: List[List[int]]
@dataclass
class PatientData:
    """One patient record from the DeepValve annotation JSON.

    Instances are built with ``PatientData(**record)`` from each JSON entry, so
    every key of the record must match a field (or InitVar) name below.
    """
    patient_name: str  # e.g. "MAD_27_0"
    key_frames: List[str]
    # Frame number as a string (e.g. "13") -> that frame's annotation dict.
    # The values are the raw JSON dicts (indexed like ["mv_insert_septal"]),
    # not ``Annotation`` instances; nothing converts them.
    annotations: Dict[str, dict]
    bounding_box: List[int]
    # InitVar: accepted by __init__ but not stored on the instance; no
    # __post_init__ is defined, so the value is discarded.
    flags: InitVar[List[int]] = None
    
def get_patient_data_list(patient_names: List[str], data_list: List[PatientData]) -> List[PatientData]:
    """Return the records of ``data_list`` whose ``patient_name`` is in ``patient_names``.

    The order of ``data_list`` is preserved, and duplicate records are kept.
    """
    selected = []
    for record in data_list:
        if record.patient_name in patient_names:
            selected.append(record)
    return selected


# Load every patient record from the DeepValve annotation export.
with open("/Users/giuliamonopoli/Desktop/PhD /deepvalve/data/new_annotations", "r") as json_file:
    data = json.load(json_file)

# One PatientData per JSON record; record keys must match the dataclass fields.
patient_data = [PatientData(**record) for record in data]
def get_patient_data(patient_name, patients_list):
    """Return the first record whose ``patient_name`` equals ``patient_name``, or None."""
    matches = (record for record in patients_list if record.patient_name == patient_name)
    return next(matches, None)



specific_patient_data = get_patient_data(f"MAD_{subject_id}_0", patient_data)
if specific_patient_data is None:
    # Previously surfaced as an opaque AttributeError on None.
    raise LookupError(f"No DeepValve annotations found for patient MAD_{subject_id}_0")

sd = specific_patient_data.annotations  # frame number (as str) -> landmark dict
if f"{target_frame_num}" not in sd:
    raise LookupError(
        f"Frame {target_frame_num} is not annotated for MAD_{subject_id}_0 "
        f"(annotated frames: {sorted(sd)})"
    )

mv_insert_septal = np.array(sd[f"{target_frame_num}"]["mv_insert_septal"])
mv_insert_lateral = np.array(sd[f"{target_frame_num}"]["mv_insert_lateral"])

In [9]:
# Plot every SAX slice (indices into dcm_sax).
n_max_sax = len(dcm_sax)
choose_sax = np.arange(0, n_max_sax)

########### Transform MV insertion points to physical coordinates ###########
# The full-image get_pixel_coordinates(dcm_file) call that used to sit here
# (and the PixelSpacing copy `px`) was unused; the plotting cell recomputes it.
# NOTE(review): annotation entries [0][1] / [0][2] are passed to
# get_3d_coordinates as (column, row) and halved -- presumably the DeepValve
# annotations were made on a 2x upsampled image; confirm.
dcm_file = dcm_lax[0]
i, j = int(mv_insert_lateral[0][1] // 2), int(mv_insert_lateral[0][2] // 2)
i2, j2 = int(mv_insert_septal[0][1] // 2), int(mv_insert_septal[0][2] // 2)
physical_coords = get_3d_coordinates(dcm_file, i, j)
physical_coords2 = get_3d_coordinates(dcm_file, i2, j2)



In [None]:

import plotly.graph_objects as go
import cv2, nibabel as nib

# Plot the LAX image with its plane normal, the mitral-valve (MV) insertion
# points and a circular mitral-annulus approximation, together with every
# selected SAX slice.

# --- LAX image -----------------------------------------------------------------
dcm_file = dcm_lax[0]
P = get_pixel_coordinates(dcm_file)
X = P[:, :, 0]
Y = P[:, :, 1]
Z = P[:, :, 2]
img = P[:, :, 3]
img = cv2.convertScaleAbs(img, alpha=255 / img.max())

# --- MV insertion points (pixel indices -> patient coordinates) -----------------
# NOTE(review): entries [0][1] / [0][2] are used as (column, row) and halved --
# presumably the annotations were made on a 2x upsampled image; confirm.
i, j = int(mv_insert_lateral[0][1] // 2), int(mv_insert_lateral[0][2] // 2)
i2, j2 = int(mv_insert_septal[0][1] // 2), int(mv_insert_septal[0][2] // 2)
physical_coords = get_3d_coordinates(dcm_file, i, j)
physical_coords2 = get_3d_coordinates(dcm_file, i2, j2)

P1 = np.array(physical_coords)   # MV insertion point 1 (lateral)
P2 = np.array(physical_coords2)  # MV insertion point 2 (septal)

# Chord between the two insertion points and its unit direction.
v = P2 - P1
chord_length = np.linalg.norm(v)
if chord_length == 0:
    raise ValueError("Lateral and septal MV insertion points coincide; cannot build the annulus.")
v_unit = v / chord_length

# LAX plane normal = row cosine x column cosine. dtype=float so the in-place
# normalisation works whatever numeric type pydicom returns.
image_orientation = dcm_file.ImageOrientationPatient
R = np.asarray(image_orientation[:3], dtype=float)
C = np.asarray(image_orientation[3:], dtype=float)
normal_vector = np.cross(R, C)
normal_vector /= np.linalg.norm(normal_vector)
# In-plane direction perpendicular to the chord (kept for inspection; not plotted).
w = np.cross(normal_vector, v)
w /= np.linalg.norm(w)

# --- Annulus approximation --------------------------------------------------------
# Circle (equal semi-axes) centred between the insertion points, lying in the
# plane spanned by the chord direction and the LAX normal.
center = (P1 + P2) / 2
semi_major_axis_length = chord_length / 2
semi_minor_axis_length = chord_length / 2
num_points = 100

# Outline of the annulus.
t = np.linspace(0, 2 * np.pi, num_points)
outline = (center[:, None]
           + semi_major_axis_length * np.cos(t) * v_unit[:, None]
           + semi_minor_axis_length * np.sin(t) * normal_vector[:, None])  # (3, num_points)
ellipse_x, ellipse_y, ellipse_z = outline

# Filled disk sampled on a polar (theta, r) grid.
theta, r = np.meshgrid(np.linspace(0, 2 * np.pi, num_points), np.linspace(0, 1, num_points))
x_ellipse = r * semi_major_axis_length * np.cos(theta)
y_ellipse = r * semi_minor_axis_length * np.sin(theta)
ellipse_points_x = center[0] + x_ellipse * v_unit[0] + y_ellipse * normal_vector[0]
ellipse_points_y = center[1] + x_ellipse * v_unit[1] + y_ellipse * normal_vector[1]
ellipse_points_z = center[2] + x_ellipse * v_unit[2] + y_ellipse * normal_vector[2]

# --- Figure ----------------------------------------------------------------------------
# Trace order matters for the opacity slider: 0 = LAX, 1..len(choose_sax) = SAX.
fig = go.Figure(data=[go.Surface(x=X, y=Y, z=Z, surfacecolor=img, colorscale="Greys_r",
                                 showscale=False, showlegend=True, name="LAX")])

for i_sax in choose_sax:
    dcm_file = dcm_sax[i_sax]
    P = get_pixel_coordinates(dcm_file)
    X, Y, Z = P[:, :, 0], P[:, :, 1], P[:, :, 2]
    img = P[:, :, 3]
    img = cv2.convertScaleAbs(img, alpha=255 / img.max())
    fig.add_surface(x=X, y=Y, z=Z,
                    name=f"SAX{i_sax}",
                    surfacecolor=img,
                    colorscale="Greys_r",
                    showlegend=True,
                    opacity=1.0,
                    showscale=False)

# LAX plane normal drawn as an arrow from the annulus centre.
fig.add_trace(go.Cone(
    x=[center[0]], y=[center[1]], z=[center[2]],
    u=[normal_vector[0]], v=[normal_vector[1]], w=[normal_vector[2]],
    sizemode="absolute",
    sizeref=5,
    anchor="tail",
    colorscale=[[0, 'red'], [1, 'red']],
    showscale=False,
    name="Normal Vector"
))

fig.add_trace(go.Scatter3d(x=[P1[0], P2[0]], y=[P1[1], P2[1]], z=[P1[2], P2[2]],
                           mode='markers', marker=dict(size=5, color='red'),
                           name="Mitral Valve Insertion Points"))

fig.add_trace(go.Scatter3d(x=ellipse_x, y=ellipse_y, z=ellipse_z,
                           mode='lines', line=dict(color='blue', width=4),
                           name="Mitral Annulus Ellipse"))

fig.add_trace(go.Scatter3d(x=ellipse_points_x.flatten(), y=ellipse_points_y.flatten(), z=ellipse_points_z.flatten(),
                           mode='markers', marker=dict(size=2, color='blue', opacity=0.5),
                           name="Points Inside Ellipse"))

fig.update_layout(title="LAX Plane with Normal Vector, Mitral Valve Insertion Points, and Mitral Annulus",
                  scene=dict(aspectmode='data'), width=800, height=600)

# Derive the SAX trace indices from what was actually added, so a subset in
# choose_sax does not make the slider restyle the wrong (or missing) traces.
sax_trace_indices = list(range(1, len(choose_sax) + 1))
slider_sax = {
    'active': 100,
    'currentvalue': {'prefix': 'Opacity: '},
    'steps': [{
        'value': step / 100,
        'label': f'{step}%',
        'visible': True,
        'execute': True,
        'method': 'restyle',
        'args': [{'opacity': step / 100}, sax_trace_indices],  # apply to SAX only
    } for step in range(101)]
}

fig.layout.scene.camera.projection.type = "orthographic"
fig.update_layout(sliders=[slider_sax])
fig.show()

In [18]:
import numpy as np
import nibabel as nib
import os
from pydicom import dcmread
import plotly.graph_objects as go
import cv2
from scipy.ndimage import zoom

"More efficient plotting of LAX, SAX and  SAX Masks"
def get_dicom_affine(dcm_file):
    """Build the 4x4 homogeneous pixel-to-patient matrix of a DICOM slice.

    The matrix columns are:
        0 -> row direction cosine (IOP[0:3]) * PixelSpacing[1]
        1 -> column direction cosine (IOP[3:6]) * PixelSpacing[0]
        2 -> zeros (single-slice plane)
        3 -> ImagePositionPatient translation
    and the last row is (0, 0, 0, 1). Under the DICOM convention, the first
    homogeneous coordinate it multiplies is the column index.
    """
    origin = np.array(dcm_file.ImagePositionPatient, dtype=float)
    cosines = np.array(dcm_file.ImageOrientationPatient, dtype=float)
    spacing = np.array(dcm_file.PixelSpacing, dtype=float)

    affine = np.zeros((4, 4))
    affine[:3, 0] = cosines[0:3] * spacing[1]
    affine[:3, 1] = cosines[3:6] * spacing[0]
    affine[:3, 3] = origin[:3]
    affine[3, 3] = 1.0
    return affine

def transform_nii_to_global(segmentation, dicom_affine):
    """Map every pixel of a 2D mask to patient coordinates with ``dicom_affine``.

    Returns
    -------
    (global_coords, values)
        global_coords : (4, height, width) homogeneous coordinates
        values        : a copy of ``segmentation`` as a (height, width) array

    NOTE(review): the first homogeneous coordinate fed to the affine is the
    array's *row* index and the second its *column* index. ``get_dicom_affine``
    scales its first column by the column spacing, so this is transposed with
    respect to the DICOM (column, row) order unless callers pass a transposed
    mask. Confirm against the callers before changing either side.
    """
    height, width = segmentation.shape
    n_pixels = height * width

    row_idx, col_idx = np.indices((height, width))
    homogeneous = np.vstack([
        row_idx.reshape(1, n_pixels),
        col_idx.reshape(1, n_pixels),
        np.zeros((1, n_pixels)),
        np.ones((1, n_pixels)),
    ])

    projected = dicom_affine @ homogeneous
    values = np.array(segmentation).reshape(height, width)
    return projected.reshape(4, height, width), values


def plot_lax_with_segmentation(dcm_sax, dcm_lax, segmentation_file):
    """Show the LAX slice, every SAX slice and its mask overlay in one 3D scene.

    Parameters
    ----------
    dcm_sax : list of SAX DICOM datasets; slice ``k`` is paired with
        ``segmentation[:, :, k]`` of the NIfTI volume.
    dcm_lax : list of LAX DICOM datasets; only the first one is drawn.
    segmentation_file : path to the NIfTI segmentation volume.

    Each mask slice is transposed and, if its shape differs from the DICOM
    pixel array, resized with nearest-neighbour interpolation before being
    mapped into patient space. The figure is opened with the 'browser' renderer.
    """
    segmentation = nib.load(segmentation_file).get_fdata()

    def to_uint8(intensity):
        # Stretch so the slice maximum maps to 255 before display.
        return cv2.convertScaleAbs(intensity, alpha=255 / intensity.max())

    fig = go.Figure()

    lax = get_pixel_coordinates(dcm_lax[0])
    fig.add_surface(x=lax[:, :, 0], y=lax[:, :, 1], z=lax[:, :, 2],
                    surfacecolor=to_uint8(lax[:, :, 3]), colorscale="Greys_r",
                    showscale=False, showlegend=True, name="LAX")

    for slice_idx, sax_file in enumerate(dcm_sax):
        sax = get_pixel_coordinates(sax_file)
        fig.add_surface(x=sax[:, :, 0], y=sax[:, :, 1], z=sax[:, :, 2],
                        surfacecolor=to_uint8(sax[:, :, 3]), colorscale="Greys_r",
                        opacity=0.8, showscale=False, showlegend=False)

        new_height, new_width = sax_file.pixel_array.shape
        mask = np.transpose(segmentation[:, :, slice_idx], (1, 0))
        if mask.shape != (new_height, new_width):
            mask = zoom(mask, (new_height / mask.shape[0], new_width / mask.shape[1]), order=0)

        mask_xyz, mask_values = transform_nii_to_global(mask, get_dicom_affine(sax_file))
        fig.add_surface(x=mask_xyz[0], y=mask_xyz[1], z=mask_xyz[2],
                        surfacecolor=mask_values, colorscale="Viridis",
                        opacity=0.4, showscale=False, showlegend=False)

    fig.update_layout(title="LAX View with Segmentation", scene=dict(aspectmode='data'), width=800, height=600)
    fig.layout.scene.camera.projection.type = "orthographic"
    fig.show(renderer='browser')


# Inputs for the optional LAX + segmentation overlay plot.
dicom_directory = (
    f"/Users/giuliamonopoli/Desktop/PhD /Data/ES_files/{subject_id}/DICOM_files"
)
segmentation_file = (
    f"/Users/giuliamonopoli/Desktop/PhD /Data/ES_files/{subject_id}/NIfTI_files/{subject_id}_myo.nii"
)

# Opens an interactive browser window; uncomment to check mask alignment.
# plot_lax_with_segmentation(dcm_sax, dcm_lax, segmentation_file)


In [26]:
" Save the global segmentation in 3D space: SAX annotations, resolution and original segmentation"

def get_value_for_case(case_name, input_file="/Users/giuliamonopoli/Desktop/PhD /shaping_mad/slice_gap.xlsx"):
    """Look up the slice-gap value for ``case_name`` in an Excel sheet.

    The first row of the sheet holds titles. It is consumed as the header and
    the two columns are renamed ``Case`` and ``Value``.

    Parameters
    ----------
    case_name : int or str convertible to int (a ValueError propagates otherwise).
    input_file : path to the workbook; defaults to the original hard-coded location.

    Returns
    -------
    int or None
        The value for the first matching case, truncated with ``int()``.
        NOTE(review): a fractional gap such as 8.8 becomes 8; confirm this is intended.
        None is returned when the case is absent, its value is empty, or
        reading/parsing fails (the error is printed).
    """
    case_id = int(case_name)
    try:
        table = pd.read_excel(input_file, names=["Case", "Value"])
        matches = table.loc[table["Case"] == case_id, "Value"]
        if matches.empty:
            return None
        value = matches.values[0]
        return int(value) if pd.notnull(value) else None
    except Exception as e:
        # Best-effort lookup: report and fall back to "no value".
        print(f"Error reading the file: {e}")
        return None
        
def save_global_sax(patient, dcm_sax, segmentation_file, output_dir=None):
    """Export the SAX mask pixels, in patient coordinates, to an HDF5 file.

    Each NIfTI slice ``segmentation[:, :, k]`` is paired with ``dcm_sax[k]``:
    - the slice is transposed;
    - it is resized with nearest-neighbour interpolation if its shape differs
      from the DICOM pixel array;
    - it is mapped to patient space with that slice's DICOM affine.
    The coordinates of all pixels with value > 0 are concatenated over slices.

    Datasets written to ``<output_dir>/<patient>_global_segmentation_3D.h5``:
        original_segmentation : the NIfTI volume as loaded
        resolution            : (PixelSpacing[0], PixelSpacing[1], slice_gap),
                                in-plane spacing taken from the last SAX slice
        x/y/z_mask_coords     : 1-D coordinates of every mask pixel

    Parameters
    ----------
    patient : case id; used for the file name and the slice-gap lookup.
    dcm_sax : list of SAX DICOM datasets in the same order as the NIfTI slices.
    segmentation_file : path to the NIfTI segmentation volume.
    output_dir : destination folder. Defaults to the notebook-level
        ``output_folder`` for backward compatibility.

    Errors are printed, matching the notebook's style. With no SAX slices or a
    missing slice gap, nothing is written.
    """
    if output_dir is None:
        output_dir = output_folder  # notebook-level global (original behaviour)

    if len(dcm_sax) == 0:
        print("Error saving file: no SAX slices were provided")
        return

    seg_nii = nib.load(segmentation_file)
    segmentation = seg_nii.get_fdata()

    # Validate before opening the file: a None slice gap used to make h5py fail
    # half-way through and leave a partial file on disk.
    slice_gap = get_value_for_case(patient)
    if slice_gap is None:
        print(f"Error saving file: no slice gap found for case {patient}; nothing was written")
        return

    x_mask_coords, y_mask_coords, z_mask_coords = [], [], []
    for i_sax, dcm_file in enumerate(dcm_sax):
        new_height, new_width = dcm_file.pixel_array.shape

        # Transposed to line up with the DICOM pixel array (presumably the NIfTI
        # slice is stored as (columns, rows) -- confirm).
        seg = np.transpose(segmentation[:, :, i_sax], (1, 0))
        if seg.shape != (new_height, new_width):
            # order=0 (nearest neighbour) keeps label values unchanged.
            zoom_factors = (new_height / seg.shape[0], new_width / seg.shape[1])
            reshaped_mask = zoom(seg, zoom_factors, order=0)
        else:
            reshaped_mask = seg

        segmentation_global, _ = transform_nii_to_global(reshaped_mask, get_dicom_affine(dcm_file))

        mask_indices = reshaped_mask > 0
        x_mask_coords.append(segmentation_global[0][mask_indices])
        y_mask_coords.append(segmentation_global[1][mask_indices])
        z_mask_coords.append(segmentation_global[2][mask_indices])

    last_spacing = dcm_sax[-1].PixelSpacing
    resolutions = float(last_spacing[0]), float(last_spacing[1]), slice_gap

    output_path = os.path.join(output_dir, f"{patient}_global_segmentation_3D.h5")
    try:
        with h5py.File(output_path, 'w') as f:
            f.create_dataset('original_segmentation', data=segmentation)
            f.create_dataset('resolution', data=resolutions)
            f.create_dataset('x_mask_coords', data=np.concatenate(x_mask_coords))
            f.create_dataset('y_mask_coords', data=np.concatenate(y_mask_coords))
            f.create_dataset('z_mask_coords', data=np.concatenate(z_mask_coords))

        print(f"Saved all data to: {output_path}")
    except Exception as e:
        print(f"Error saving file: {e}")


# Inputs/outputs for the subject whose SAX segmentation is exported below.
dicom_directory = f"/Users/giuliamonopoli/Desktop/PhD /Data/ES_files/{subject_id}/DICOM_files"
segmentation_file = f"/Users/giuliamonopoli/Desktop/PhD /Data/ES_files/{subject_id}/NIfTI_files/{subject_id}_myo.nii"
output_folder = f"/Users/giuliamonopoli/Desktop/PhD /Data/ES_files/{subject_id}/Global_Segmentations"

# Create the destination folder (and any missing parents) if it does not exist yet.
Path(output_folder).mkdir(parents=True, exist_ok=True)

save_global_sax(subject_id, dcm_sax, segmentation_file)


Saved all data to: /Users/giuliamonopoli/Desktop/PhD /Data/ES_files/27/Global_Segmentations/27_global_segmentation_3D.h5


In [20]:
" check h5 file"

def load_mask_coordinates(file_path):
    """Read the x/y/z mask-coordinate datasets from an HDF5 export.

    Returns a tuple ``(x, y, z)`` of arrays read fully into memory; the file is
    closed before returning.
    """
    dataset_names = ('x_mask_coords', 'y_mask_coords', 'z_mask_coords')
    with h5py.File(file_path, 'r') as h5:
        coords = tuple(h5[name][:] for name in dataset_names)
    return coords

def plot_mask_coordinates(x_mask_coords, y_mask_coords, z_mask_coords):
    """Draw a matplotlib 3D scatter of mask pixel coordinates (small red points)."""
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={"projection": "3d"})
    ax.scatter(x_mask_coords, y_mask_coords, z_mask_coords, c='r', s=1, alpha=0.5)
    ax.set(xlabel='X Coordinate', ylabel='Y Coordinate', title='3D Scatter Plot of Mask Coordinates')
    ax.set_zlabel('Z Coordinate')
    plt.show()


# Read back the export for the current subject. The path previously hard-coded
# subject 27, so this check silently inspected the wrong file for other subjects.
mask_coords_file_path = (
    f"/Users/giuliamonopoli/Desktop/PhD /Data/ES_files/{subject_id}/Global_Segmentations/"
    f"{subject_id}_global_segmentation_3D.h5"
)

# Load and plot the mask coordinates
x_mask_coords, y_mask_coords, z_mask_coords = load_mask_coordinates(mask_coords_file_path)
plot_mask_coordinates(x_mask_coords, y_mask_coords, z_mask_coords)
