In [1]:
import os
from pymskt.mesh import Mesh
import SimpleITK as sitk
import numpy as np
import nibabel as nib

from pymskt.image import read_nrrd
from pymskt.mesh.meshTransform import SitkVtkTransformer
from pymskt.mesh.meshTools import ProbeVtkImageDataAlongLine
from pymskt.mesh.meshTools import get_surface_normals, n2l, l2n
from pymskt.mesh.utils import is_hit, get_intersect, get_surface_normals, get_obb_surface

In [None]:
def categorize_subject_data(subject_file_names,
                            reversed_order_visits=frozenset({('24-P', 'VISIT-1')})):
    '''
    Group scan folder names into a nested dict: subject -> visit -> knee -> folder name.

    Parameters
    ----------
    subject_file_names : list of str
        Folder names. Characters ``[-35:-31]`` of each name are taken as the
        subject id (e.g. ``'27-P'``), the substring ``'VISIT-<k>'`` (k = 1..5)
        identifies the visit, and the last 5 characters are the exam number.
    reversed_order_visits : collection of (subject_number, visit_key) tuples
        Visits where the knee assignment by exam number is inverted (the
        higher exam number is the injured knee). Defaults to the one known
        exception, ``('24-P', 'VISIT-1')``.

    Returns
    -------
    dict
        ``{subject_number: {'VISIT-k': knee_dict}}`` with subjects in sorted
        order. ``knee_dict`` is ``{'injured': ..., 'contralateral': ...}``
        when the visit has exactly 2 folders, ``{'NA': ...}`` when it has
        exactly 1. Visits with 0 or more than 2 folders are omitted.
    '''
    # Fixed-position slice of the folder name holds the subject id.
    subject_numbers = sorted({file[-35:-31] for file in subject_file_names})

    subject_data = {}

    for subject_number in subject_numbers:
        # All folders belonging to this subject, in input order.
        subject_files = [file for file in subject_file_names if file[-35:-31] == subject_number]

        visits = {}
        for visit_number in range(1, 6):
            key = f'VISIT-{visit_number}'
            visit_files = [file for file in subject_files if key in file]

            knee = {}
            if len(visit_files) == 2:
                first, second = visit_files
                # Exam numbers are the last 5 characters, compared as strings.
                if first[-5:] < second[-5:]:
                    lower, higher = first, second
                else:
                    lower, higher = second, first

                if (subject_number, key) in reversed_order_visits:
                    # Exception: the higher exam number is the injured knee.
                    knee['injured'], knee['contralateral'] = higher, lower
                else:
                    # Normal case: the lower exam number is the injured knee and the
                    # higher one the contralateral knee.
                    # NOTE(review): the previous comments stated the opposite; the code's
                    # behavior is preserved here — confirm against the study protocol.
                    knee['injured'], knee['contralateral'] = lower, higher

            elif len(visit_files) == 1:
                # Single scan: knee side cannot be determined from the exam number.
                knee['NA'] = visit_files[0]

            visits[key] = knee

        # Drop visits without any usable scan (0 or more than 2 folders).
        subject_data[subject_number] = {k: v for k, v in visits.items() if v}

    return subject_data

In [None]:
os.environ['DIR'] = '/dataNAS/people/anoopai/DESS_ACL_study'
os.environ['DATA'] = 'data_processed'
os.environ['LOG'] = 'notebooks_dosma_registration_pipeline/logs'
os.environ['FILES'] = 'notebooks_dosma_registration_pipeline/files'
os.environ['RESULTS'] = 'results'

# Resolve project paths from the environment configured above.
dir_path = os.environ['DIR']
log_path = os.path.join(os.environ['DIR'], os.environ['LOG'])
file_path = os.path.join(os.environ['DIR'], os.environ['FILES'])
data_dir_path = os.path.join(os.environ['DIR'], os.environ['DATA'])
results_path = os.path.join(os.environ['DIR'], os.environ['RESULTS'])

# Change working directory to the directory containing data.
os.chdir(data_dir_path)

# List all entries (patient-visit-leg folders), skipping pipeline log files and NFS temp files.
dirs = [
    item for item in os.listdir(data_dir_path)
    if 'stdout.nipype' not in item and '.nfs00' not in item
]

# Keep folders that contain a 'results' sub-entry and categorise them as Patient or Control.
subject_files = []
patient_file_names = []
control_file_names = []

for folder in dirs:
    folder_path = os.path.join(data_dir_path, folder)
    # Plain files would make os.listdir raise NotADirectoryError, so skip them.
    if not os.path.isdir(folder_path) or 'results' not in os.listdir(folder_path):
        continue

    subject_files.append(folder)

    if '-P-' in folder:
        patient_file_names.append(folder)
    elif '-C-' in folder:
        control_file_names.append(folder)

# Sort and categorise once, after every folder has been scanned. Doing this outside the
# loop also guarantees patient_data/control_data exist even when no folders were found.
patient_file_names.sort()
control_file_names.sort()

patient_data = categorize_subject_data(patient_file_names)
control_data = categorize_subject_data(control_file_names)

In [4]:
# Example pair to visualise: the injured knee of subject 27-P at VISIT-1 and VISIT-3.
scan0 = patient_data['27-P']['VISIT-1']['injured']
scan1 = patient_data['27-P']['VISIT-3']['injured']

seg_mask_type = 'dosma'
status = 'reg2baseline'   # alternative: 'reg2timepoint'
quant_type = 't2'         # alternative: 't1_rho'
cluster_type = 'pos'      # alternative: 'neg'
data_name = f'{quant_type}_{status}_{seg_mask_type}'

timepoint0 = os.path.join(data_dir_path, scan0)
timepoint1 = os.path.join(data_dir_path, scan1)

# --- Inputs: segmentation and quantitative maps for each timepoint ---
fc_path0 = os.path.join(timepoint0, f'results/{data_name}/fc/fc.nii.gz')
t2_path0 = os.path.join(timepoint0, f'results/{data_name}/fc/{quant_type}/{quant_type}_filtered.nii.gz')
fc_path1 = os.path.join(timepoint1, f'results/{data_name}/fc/fc.nii.gz')
t2_path1 = os.path.join(timepoint1, f'results/{data_name}/fc/{quant_type}/{quant_type}_filtered.nii.gz')

# --- Inputs: cluster-analysis maps, stored under the follow-up timepoint ---
cluster_dir1 = os.path.join(timepoint1, f'results/cluster_analysis/{data_name}')
diff_maps_path = os.path.join(cluster_dir1, 'difference_maps.nii.gz')
intensity_thresh_path = os.path.join(cluster_dir1, f'difference_maps_intensity_threshold_{cluster_type}.nii.gz')
volume_thresh_path = os.path.join(cluster_dir1, f'cluster_{cluster_type}_all/fc/{quant_type}/{quant_type}.nii.gz')

# --- Outputs: all written into one shared directory ---
surface_plot_dir = os.path.join(results_path, 'difference_maps/cluster_results_all/surface_plot_data')
# Later cells write .nrrd/.vtk files here; create the directory up front so a fresh
# results tree does not make those writes fail.
os.makedirs(surface_plot_dir, exist_ok=True)

fc_save_path0 = os.path.join(surface_plot_dir, 'fc0.nrrd')
t2_save_path0 = os.path.join(surface_plot_dir, 't20.nrrd')
fc_save_path1 = os.path.join(surface_plot_dir, 'fc1.nrrd')
t2_save_path1 = os.path.join(surface_plot_dir, 't21.nrrd')
diff_maps_save_path = os.path.join(surface_plot_dir, 'diff_maps.nrrd')
intensity_thresh_save_path = os.path.join(surface_plot_dir, 'intensity_threshold.nrrd')
volume_thresh_save_path = os.path.join(surface_plot_dir, 'volume_threshold.nrrd')

t2_3D_save_path0 = os.path.join(surface_plot_dir, 't20.vtk')
t2_3D_save_path1 = os.path.join(surface_plot_dir, 't21.vtk')
diff_maps_3D_save_path = os.path.join(surface_plot_dir, 'diff_maps.vtk')
intensity_thresh_3D_save_path = os.path.join(surface_plot_dir, 'intensity_threshold.vtk')
# NOTE(review): 'threhsold' is misspelled (the .nrrd sibling uses 'threshold'); kept as-is
# so previously generated files and any downstream readers still match. Rename deliberately.
volume_thresh_3D_save_path = os.path.join(surface_plot_dir, 'volume_threhsold.vtk')

In [5]:
"""
To apply this pipeline to other data, you will need to:
1. change fc_path to be the appropriate segmentation file
2. change t2_path to whatever you want to use for analysis. E.g.
    - difference_maps.nii
    - pre or post T2
    - cluster maps
    - etc.
"""
fc0 = sitk.ReadImage(fc_path0)
sitk.WriteImage(fc0, fc_save_path0)

t20 = sitk.ReadImage(t2_path0)
sitk.WriteImage(t20, t2_save_path0)

fc1 = sitk.ReadImage(fc_path1)
sitk.WriteImage(fc1, fc_save_path1)

t21 = sitk.ReadImage(t2_path1)
sitk.WriteImage(t21, t2_save_path1)

diff_maps= sitk.ReadImage(diff_maps_path)
sitk.WriteImage(diff_maps, diff_maps_save_path)

intensity_thresh= sitk.ReadImage(intensity_thresh_path)
sitk.WriteImage(intensity_thresh, intensity_thresh_save_path)

volume_thresh= sitk.ReadImage(volume_thresh_path)
sitk.WriteImage(volume_thresh, volume_thresh_save_path)

In [6]:
def create_surace_plot(t2_path, fc_path, save_path, *,
                       label_idx=1,
                       smooth_image_var=0.2,
                       n_clusters=10000,
                       line_resolution=10000,
                       ray_length=-10,
                       percent_ray_length_opposite_direction=1.0,
                       extension_fraction=0.1,
                       filler=0,
                       scalar_prefix='t2'):
    """
    Project image values onto a cartilage surface mesh and save the mesh.

    A surface mesh is built from the segmentation at ``fc_path`` and resampled.
    For every mesh point, a ray along the point's surface normal is intersected
    with the mesh surface. When exactly two intersections are found, the image at
    ``t2_path`` is sampled along the segment between them, padded at both ends.
    Mean and max of the sampled values are stored as point scalars. The mesh is
    written to ``save_path``.

    (The function name keeps its historical spelling so existing callers work.)

    Parameters
    ----------
    t2_path : str
        Image to sample (read both with SimpleITK and with ``read_nrrd``).
    fc_path : str
        Segmentation image used to build the surface mesh.
    save_path : str
        Output mesh path passed to ``Mesh.save_mesh``.
    label_idx : int
        Segmentation label to mesh.
    smooth_image_var : float
        Smoothing passed to ``create_mesh``. If parts of the surface are missing,
        try decreasing it; if the surface looks jagged, try increasing it.
    n_clusters : int
        Target number of points for ``resample_surface``.
    line_resolution : int
        Number of samples taken along each probe line.
    ray_length : float
        Length of the ray along the normal. The original author used a negative
        value to point inwards.
    percent_ray_length_opposite_direction : float
        Fraction of ``ray_length`` by which the ray also extends the other way
        (1.0 = 100%), so both cartilage edges are crossed.
    extension_fraction : float
        Each end of the intersection segment is pushed outwards by this fraction
        of the segment length before sampling.
    filler : number
        Value recorded for points whose ray does not give exactly two intersections.
    scalar_prefix : str
        Scalars are saved as ``f'{scalar_prefix}_max'`` and ``f'{scalar_prefix}_mean'``.
    """
    # Imports are kept local so the function can be copied and run on its own.
    import SimpleITK as sitk
    import numpy as np

    from pymskt.mesh import Mesh
    from pymskt.image import read_nrrd
    from pymskt.mesh.meshTransform import SitkVtkTransformer
    from pymskt.mesh.meshTools import ProbeVtkImageDataAlongLine, n2l, l2n
    # get_surface_normals comes from mesh.utils; the previous version also imported
    # it from meshTools, but the later utils import was the binding actually used.
    from pymskt.mesh.utils import get_intersect, get_surface_normals, get_obb_surface

    # Build a surface mesh from the segmentation, then resample it to spread points evenly.
    mesh = Mesh(path_seg_image=fc_path, label_idx=label_idx)
    mesh.create_mesh(smooth_image_var=smooth_image_var)
    mesh.resample_surface(clusters=n_clusters)

    # Read the image twice: with SimpleITK (for its transform) and with VTK, origin set to
    # zero. The VTK reader does not account for rotations, so the mesh is moved into the
    # same frame by applying the inverse of the image transform.
    sitk_image = sitk.ReadImage(t2_path)
    nrrd_t2 = read_nrrd(t2_path, set_origin_zero=True).GetOutput()

    nrrd_transformer = SitkVtkTransformer(sitk_image)
    mesh.apply_transform_to_mesh(transform=nrrd_transformer.get_inverse_transform())

    data_probe = ProbeVtkImageDataAlongLine(
        line_resolution,
        nrrd_t2,
        save_mean=True,
        save_max=True,
        save_std=False,
        save_most_common=False,
        filler=filler,
    )

    # Points and normals we iterate over, plus an OBB structure for ray intersection queries.
    points = mesh.mesh.GetPoints()
    normals = get_surface_normals(mesh.mesh)
    point_normals = normals.GetOutput().GetPointData().GetNormals()
    obb_cartilage = get_obb_surface(mesh.mesh)

    for idx in range(points.GetNumberOfPoints()):
        point = l2n(points.GetPoint(idx))
        normal = l2n(point_normals.GetTuple(idx))

        # Ray spanning both sides of the surface point along its normal.
        end_point_ray = n2l(point + ray_length * normal)
        start_point_ray = n2l(point - ray_length * percent_ray_length_opposite_direction * normal)

        points_intersect, _ = get_intersect(obb_cartilage, start_point_ray, end_point_ray)

        # Two intersections are treated as the two cartilage boundaries; sample only
        # between them (slightly padded). Anything else gets the filler value.
        if len(points_intersect) == 2:
            start = np.asarray(points_intersect[0])
            end = np.asarray(points_intersect[1])

            # Pad both ends by the same amount. The previous code extended `end` using the
            # already-moved `start`, which padded the far end by 11% instead of 10%.
            offset = (end - start) * extension_fraction
            data_probe.save_data_along_line(start_pt=start - offset,
                                            end_pt=end + offset)
        else:
            data_probe.append_filler()

    # Put the mesh back in its original position before attaching data and saving.
    mesh.reverse_all_transforms()

    mesh.set_scalar(f'{scalar_prefix}_max', data_probe.max_data)
    mesh.set_scalar(f'{scalar_prefix}_mean', data_probe.mean_data)
    mesh.save_mesh(save_path)

In [7]:
# (image to project, segmentation for the surface, output mesh). The baseline map uses the
# baseline segmentation; the follow-up and cluster-analysis maps use the follow-up one.
surface_jobs = [
    (t2_save_path0, fc_save_path0, t2_3D_save_path0),
    (t2_save_path1, fc_save_path1, t2_3D_save_path1),
    (diff_maps_save_path, fc_save_path1, diff_maps_3D_save_path),
    (intensity_thresh_save_path, fc_save_path1, intensity_thresh_3D_save_path),
    (volume_thresh_save_path, fc_save_path1, volume_thresh_3D_save_path),
]
for image_path, seg_path, mesh_out_path in surface_jobs:
    create_surace_plot(image_path, seg_path, mesh_out_path)