In [None]:
%load_ext autoreload
%autoreload 2
#%matplotlib widget
#%matplotlib ipympl

#%reload_ext tensorboard
#%matplotlib qt

In [None]:
import os
import time
from pathlib import Path
from datetime import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
from tqdm.notebook import tqdm
import pickle, subprocess
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram
from scipy.ndimage import label as scipy_label
import torch
import sklearn
import csv
import gc
import pydicom
import networkx as nx
import copy
import scipy
#from radiomics import featureextractor
#import radiomics

import glob
from platipy.imaging import ImageVisualiser
from platipy.dicom.io.rtstruct_to_nifti import convert_rtstruct, read_dicom_image

#%matplotlib notebook
#%matplotlib widget
#plt.ion()
#import initial_ml as iml

In [None]:

# Locations of the LIDC-IDRI data and of the outputs derived from it.
data_directory = '../../data/LIDC-IDRI/'
out_directory = '../../data/LIDC-IDRI/ct_slices/'
nii_directory = '../../data/LIDC-IDRI/Nii'

data_path = Path(data_directory)
out_path = Path(out_directory)
nii_path = Path(nii_directory)
embedding_path = data_path.joinpath('embeddings')

# Make sure every output directory exists before later cells write into it.
for output_dir in (embedding_path, out_path, nii_path):
    output_dir.mkdir(exist_ok=True, parents=True)

# Metadata tables of the CT and SEG downloads, indexed by patient ('Subject ID').
meta_ct_df = pd.read_csv(data_path.joinpath('ct_dicom/metadata.csv')).set_index('Subject ID')
meta_seg_df = pd.read_csv(data_path.joinpath('seg_dicom/metadata.csv')).set_index('Subject ID')

# 'ct_slices_dict.pkl' is written by a later cell of this notebook, so it does not exist on a
# first run. Fall back to an empty dictionary instead of failing; the slice-extraction cells
# rebuild `ct_slices` from scratch.
ct_slices_cache = Path('./ct_slices_dict.pkl')
ct_slices = pd.read_pickle(ct_slices_cache) if ct_slices_cache.exists() else {}

## Slice extraction
The following blocks collect the slice information from the SEG files.
Each SEG file is read to get the Instance UIDs of the marked slices; only the unique UIDs are kept, because several SEG files can refer to the same node.
Additionally, the center of mass of each node is calculated.

In [None]:
# Pull the UIDs of the CT slices referenced by every SEG file (used later to pick the matching
# CT files) and the centre of mass of every segmentation.
segment_slices = {}  # patient -> node -> referenced SOP Instance UIDs
segment_com = {}     # patient -> node -> list with one centre-of-mass array per SEG file
for pat, df_group in tqdm(meta_seg_df.groupby("Subject ID")):
    segment_slices[pat] = {}
    segment_com[pat] = {}
    for _, seg_row in df_group[df_group.Modality == "SEG"].iterrows():
        seg_dir = data_path.joinpath('seg_dicom').joinpath(seg_row["File Location"].replace('\\','/'))
        # Node identifier parsed from the SEG folder name: element 1 of the '-' split,
        # then the second-to-last space-separated token of that element.
        seg_num = seg_dir.as_posix().split('/')[-1].split('-')[1].split(' ')[-2]
        seg_file = pydicom.dcmread(seg_dir.joinpath('1-1.dcm'))

        # Several SEG files can share a node identifier, so the results are accumulated per node.
        segment_com[pat].setdefault(seg_num, []).append(
            np.array(scipy.ndimage.center_of_mass(seg_file.pixel_array))
        )
        node_uids = segment_slices[pat].setdefault(seg_num, [])
        for uid_source in seg_file.ReferencedSeriesSequence[0].ReferencedInstanceSequence:
            node_uids.append(str(uid_source.ReferencedSOPInstanceUID))

# Clean-up pass, run once after all patients have been collected (it used to be nested inside
# the patient loop and therefore re-processed every patient on every iteration).
for pat in segment_slices.keys():
    for seg_num in segment_slices[pat].keys():
        segment_slices[pat][seg_num] = np.unique(segment_slices[pat][seg_num])
        n_slices = len(segment_slices[pat][seg_num])
        new_com = []
        for com in segment_com[pat][seg_num]:
            # A 2-element centre of mass holds the two in-plane coordinates; a longer one
            # carries them in positions 1 and 2. The first coordinate is replaced by half
            # the number of unique referenced slices.
            in_plane = (com[0], com[1]) if len(com) < 3 else (com[1], com[2])
            new_com.append(np.array([np.float64(n_slices/2.), in_plane[0], in_plane[1]]))
        segment_com[pat][seg_num] = new_com

In [None]:
# Reduce the centres of mass to a single (mean) centre of mass per node.
for pat in segment_com.keys():
    print(pat)
    for node in segment_com[pat].keys():
        print(f'    {node}')
        print(segment_com[pat][node])
        # Index with the loop variable `node` (the previous code used `seg_num`, a name left
        # over from an earlier cell). `atleast_2d` keeps the cell idempotent: an entry that has
        # already been averaged is treated as a single row and stays unchanged.
        segment_com[pat][node] = np.atleast_2d(segment_com[pat][node]).mean(axis=0)

In [None]:
# Takes the slice UIDs from 'segment_slices' and associates them with specific DICOM files.
# The slice numbers taken from those file names are stored in the dictionary 'ct_slices'.
ct_slices_tmp = {}  # patient -> node -> sorted slice numbers, including nodes without a match
ct_slices = {}      # patient -> node -> sorted slice numbers, only nodes with >= 1 match

for pat, df_group in tqdm(meta_ct_df.groupby("Subject ID")):
    if pat not in segment_slices:
        continue
    ct_slices_tmp[pat] = {}
    ct_slices[pat] = {}
    # Sets of plain strings make the per-file membership test cheap and type independent.
    node_uids = {node: set(map(str, uids)) for node, uids in segment_slices[pat].items()}

    for _, ct_row in df_group[df_group.Modality == "CT"].iterrows():
        ct_dir = data_path.joinpath('ct_dicom').joinpath(ct_row["File Location"].replace('\\','/'))
        # As before, the lists are restarted for every CT row, so the last CT row of a
        # patient determines the stored result.
        matches = {node: [] for node in node_uids}

        for ct_file in ct_dir.glob('*'):
            if '.xml' in ct_file.as_posix():
                continue
            # Only the header is needed to read the UID, so the pixel data is not loaded.
            ct_header = pydicom.dcmread(ct_file, stop_before_pixels=True)
            instance_uid = str(ct_header.SOPInstanceUID)
            # Slice number = text after the last '-' of the file name without its suffix.
            slice_number = ct_file.with_suffix('').name.split('-')[-1]
            for node, uids in node_uids.items():
                if instance_uid in uids:
                    matches[node].append(slice_number)

        ct_slices_tmp[pat] = {node: sorted(numbers) for node, numbers in matches.items()}
        # Nodes without any matching CT slice are left out of the final dictionary.
        ct_slices[pat] = {node: numbers for node, numbers in ct_slices_tmp[pat].items() if len(numbers) >= 1}

ct_slices

In [None]:
# Persist the slice dictionaries: 'ct_slices' is written first, 'segment_slices' second,
# both into the same pickle file.
with open("ct_slices_dict.pkl", 'wb') as pickle_file:
    for slice_dict in (ct_slices, segment_slices):
        pickle.dump(slice_dict, pickle_file)

In [None]:
# Crop the CT volume of every patient to the slice range of each node and save it as NIfTI.
problem_images = []  # (patient, node) pairs whose crop contains no slices
for patient in tqdm(ct_slices.keys()):
    pat_df = meta_ct_df.loc[[patient]]
    patient_nii_path = nii_path.joinpath(patient)
    patient_nii_path.mkdir(exist_ok=True, parents=True)

    # Read the CT series listed first for this patient.
    ct_rows = pat_df[pat_df["Modality"] == "CT"]
    ct_location = ct_rows["File Location"].iloc[0].replace('\\','/')
    ct_directory = data_path.joinpath('ct_dicom').joinpath(ct_location)
    ct_image = read_dicom_image(ct_directory)
    axial_size = ct_image.GetDepth()
    print(axial_size)

    for node, node_files in ct_slices[patient].items():
        output_file = patient_nii_path.joinpath(f"image_{node}.nii.gz")
        # The slice numbers from the file names are subtracted from the volume depth to get
        # the indices used for cropping along the third image axis.
        slices = axial_size - np.array([int(number) for number in node_files])
        print(patient, node)
        print(node_files)
        print(slices)
        print(ct_image.GetSize())
        first_index = slices.min()
        last_index = slices.max()
        modified_image = ct_image[:, :, first_index:last_index+1]
        print(modified_image.GetSize())
        if modified_image.GetSize()[-1] < 1:
            # Nothing to write: remember the case and move on to the next node.
            problem_images.append((patient, node))
            continue
        sitk.WriteImage(modified_image, str(output_file))


## Embedding calculation
The following blocks calculate foundation embeddings using the Harvard AIM CT foundation model.

In [None]:
from lighter_zoo import SegResEncoder
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground
)
from monai.inferers import SlidingWindowInferer

In [None]:
# Load the pretrained encoder as a feature extractor and switch it to evaluation mode.
model = SegResEncoder.from_pretrained("project-lighter/ct_fm_feature_extractor")
model.eval()

# Preprocessing pipeline applied to an image path before it is passed to the model.
preprocess_steps = [
    # Load the image and ensure a channel dimension
    LoadImage(ensure_channel_first=True),
    # Ensure correct data type
    EnsureType(),
    # Standardize orientation
    Orientation(axcodes="SPL"),
    # Map intensities from [-1024, 2048] to [0, 1], clipping values outside that range
    ScaleIntensityRange(a_min=-1024, a_max=2048, b_min=0, b_max=1, clip=True),
    # Remove background to reduce computation
    CropForeground(),
]
preprocess = Compose(preprocess_steps)

In [None]:
# Loop through the dictionary of CT slices and calculate the embedding of every node image.
# Each embedding is dumped into its own pickle file.

# Patients whose number (text after the last '-' of the patient id) is below this value are
# skipped. Lower it to 0 to process every patient.
FIRST_PATIENT_NUMBER = 340

lung_embeddings = {}

for patient in tqdm(ct_slices.keys()):
    lung_embeddings[patient] = {}
    if int(patient.split('-')[-1]) < FIRST_PATIENT_NUMBER:
        continue
    print(patient)
    patient_emb_path = embedding_path.joinpath(patient)
    patient_emb_path.mkdir(exist_ok=True, parents=True)
    for node in ct_slices[patient].keys():
        print(f'    {node}')
        input_path = nii_path.joinpath(patient).joinpath(f'image_{node}.nii.gz')
        input_tensor = preprocess(input_path.as_posix())
        with torch.no_grad():
            # The last element of the model output is used as the feature map.
            output = model(input_tensor.unsqueeze(0))[-1]

            # Average pooling compresses the feature vector across all patches. If this is not
            # desired, remove this line and use the output tensor directly, which gives the
            # feature maps in a low-dimensional space.
            avg_output = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()

        # The `with` block closes the file; no explicit close() is needed.
        with open(patient_emb_path.joinpath(f"embedding_{patient}_{node}.pkl"), "wb") as f:
            pickle.dump(avg_output, f)

        # Release all per-node tensors (including the feature map) before the next node.
        del input_tensor, output, avg_output


## Resampling code

In [None]:
# Shared resampling filter: output spacing of 1 along every axis, identity output direction.
resampling = [1,1,1]
resampler = sitk.ResampleImageFilter()
resampler.SetOutputSpacing(resampling)
resampler.SetOutputDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])

In [None]:
def get_bouding_boxes(ct, pt):
    """
    Return the bounding box shared by two images.

    Both arguments must provide ``GetOrigin()``, ``GetSize()`` and ``GetSpacing()``.
    The far corner of each image is computed as ``origin + size * spacing``, which
    is only meaningful when both images have the same direction.

    Returns a 1-D array: the element-wise maximum of the two origins followed by
    the element-wise minimum of the two far corners.
    """

    def _corners(image):
        # Lower corner (origin) and upper corner (origin + size * spacing) of one image.
        lower = np.array(image.GetOrigin())
        upper = lower + np.array(image.GetSize()) * np.array(image.GetSpacing())
        return lower, upper

    ct_lower, ct_upper = _corners(ct)
    pt_lower, pt_upper = _corners(pt)

    shared_lower = np.maximum(ct_lower, pt_lower)
    shared_upper = np.minimum(ct_upper, pt_upper)
    return np.concatenate([shared_lower, shared_upper], axis=0)

In [None]:
def resample_one_patient(p):
    """
    Resample the image and the masks of one patient folder and measure the masks.

    Parameters
    ----------
    p : pathlib.Path
        Patient folder containing 'image.nii.gz' and the mask files ('*.nii.gz').

    Returns
    -------
    list of numpy.ndarray, or None
        For every mask, ``max - min`` of the indices of its non-zero voxels along each
        array axis after resampling. ``None`` is returned when the image cannot be read.

    Notes
    -----
    Uses the module-level ``resampler``, ``resampling`` and ``resample_path`` and changes
    the origin, size and interpolator of ``resampler``. Writing the resampled files is
    currently disabled; only the output folder is created.
    """
    pat_str = p.as_posix().split('/')[-1]
    patient_resample_path = resample_path.joinpath(pat_str)
    patient_resample_path.mkdir(exist_ok=True, parents=True)
    try:
        ct = sitk.ReadImage(p.joinpath('image.nii.gz').as_posix())
    except Exception:
        # Best effort: report the unreadable patient and carry on. `Exception` (instead of
        # a bare except) lets KeyboardInterrupt and SystemExit propagate.
        print(f"    unable to read image file for {pat_str}")
        return

    # The bounding box of the image with itself is simply the extent of the image.
    bb = get_bouding_boxes(ct, ct)
    size = np.round((bb[3:] - bb[:3]) / resampling).astype(int)
    resampler.SetOutputOrigin(bb[:3])
    resampler.SetSize([int(k) for k in size])  # SetSize needs plain Python ints
    resampler.SetInterpolator(sitk.sitkBSpline)
    ct = resampler.Execute(ct)

    #sitk.WriteImage(ct, patient_resample_path.joinpath('image.nii.gz').as_posix())
    # Masks are label images, so they are resampled with nearest-neighbour interpolation.
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)

    mask_sizes = []
    for m in p.glob('*.nii.gz'):
        # Skip the image itself. Only the file name is checked, so a parent folder whose
        # name contains 'image' does not hide every mask.
        if 'image' in m.name:
            continue
        label = sitk.ReadImage(m.as_posix())
        label = resampler.Execute(label)

        label_array = sitk.GetArrayViewFromImage(label)
        label_locations = np.where(label_array > 0)
        mask_sizes.append(np.max(label_locations, axis=1) - np.min(label_locations, axis=1))
        #sitk.WriteImage(label, patient_resample_path.joinpath(m.as_posix().split('/')[-1]).as_posix())
    return mask_sizes

## 4. Cropping

In [None]:
def tune_range(min_d, max_d, d, size_d, p):
    """
    Clip the index range [min_d, max_d] to [0, d] and report the clipped amounts.

    Parameters
    ----------
    min_d, max_d : int
        Requested lower and upper index along one axis.
    d : int
        Upper limit for ``max_d`` (the image size along that axis).
    size_d, p
        Not used by the current implementation; kept for the callers.

    Returns
    -------
    tuple
        ``(min_d, max_d, min_pad, max_pad)``: the clipped range and, as ints, how far the
        request reached below 0 and beyond ``d``.
    """
    lower_pad = 0
    upper_pad = 0

    # Part of the range lies below index 0.
    if min_d < 0:
        lower_pad, min_d = abs(min_d), 0

    # Part of the range lies beyond the limit d.
    if max_d > d:
        upper_pad, max_d = max_d - d, d

    return min_d, max_d, int(lower_pad), int(upper_pad)
# Crop an 80 x 80 x 80 patch around the centroid of every mask, zero-pad patches that reach
# beyond the image border, write image and mask patches, and collect the centroid locations.
# NOTE(review): `resample_path`, `patch_path` and `find_centroid` are not defined in the visible
# part of this notebook — they presumably come from another cell or module; confirm.
physical_locations = {}
for p_dir in tqdm(list(resample_path.glob('*'))):
    p_str = p_dir.as_posix().split('/')[-1]
    print(p_str)
    #if p_str not in patients_to_retry: continue
    #try:
    #if p_str in patients_to_drop:
    #    print('skip ', p_str)
    #    continue
    patient_patch_path = patch_path.joinpath(p_str)
    patient_patch_path.mkdir(exist_ok=True, parents=True)
    physical_locations[p_str] = {}
    patch_size = np.array([80,80,80])
    for m in p_dir.glob('*.nii.gz'):
        print('-----------------')
        m_str = m.as_posix().split('/')[-1]
        # every file whose name contains 'image' is skipped; the rest are treated as masks
        if 'image' in m_str: continue
        #try:
        # the image is re-read for every mask because it is cropped in place below
        image = sitk.ReadImage(p_dir.joinpath('image.nii.gz').as_posix())
        mask = sitk.ReadImage(m.as_posix())
        print(m_str)
        #crop the image to patch_size around the tumor center
        # presumably (centroid in index coordinates, centroid location to store) — TODO confirm against find_centroid
        tumour_center, center_location = find_centroid(mask, p_str) # center of GTV
        size = patch_size
        min_coords = np.floor(tumour_center - size / 2).astype(np.int64)
        max_coords = np.floor(tumour_center + size / 2).astype(np.int64)
        min_x, min_y, min_z = min_coords
        max_x, max_y, max_z = max_coords
        (img_x, img_y, img_z)=image.GetSize()
        # clip the requested range to the image per axis and get the amount of padding needed
        min_x, max_x, min_pad_x, max_pad_x = tune_range(min_x, max_x, img_x, size[0], p_str) 
        min_y, max_y, min_pad_y, max_pad_y = tune_range(min_y, max_y, img_y, size[1], p_str) 
        min_z, max_z, min_pad_z, max_pad_z = tune_range(min_z, max_z, img_z, size[2], p_str) 

        # min_pad / max_pad (largest padding over the axes) are only printed below
        min_pad = int(max([min_pad_x, min_pad_y, min_pad_z]))
        max_pad = int(max([max_pad_x, max_pad_y, max_pad_z]))
        # per-axis lower / upper padding passed to ConstantPad
        lpad = list([min_pad_x, min_pad_y, min_pad_z])
        upad = list([max_pad_x, max_pad_y, max_pad_z])
        #print(m_str)
        #print(lpad)
        #print(upad)
        print(image.GetSize())
        print(min_coords, max_coords)
        print(min_pad, max_pad)
        image = image[min_x:max_x, min_y:max_y, min_z:max_z]
        # clamp image intensities to [-500, 500] and cast to float32
        image = sitk.Clamp(image, sitk.sitkFloat32, -500, 500)
        mask = mask[min_x:max_x, min_y:max_y, min_z:max_z]
        print(image.GetSize())
        # pad with the constant 0.0 wherever the patch reached beyond the image
        image = sitk.ConstantPad(image, lpad, upad, 0.0)
        mask = sitk.ConstantPad(mask, lpad, upad, 0.0)
        print(image.GetSize())
        # image patch is named after the mask file, with 'Struct_' removed and 'image_' prepended
        sitk.WriteImage(image, patient_patch_path.joinpath(f"image_{m_str.replace('Struct_','')}").as_posix())
        sitk.WriteImage(mask, patient_patch_path.joinpath(m_str).as_posix())
        physical_locations[p_str][m_str.replace('Struct_','').replace('.nii.gz','')] = center_location
        del(image)
        del(mask)
        #except:
        #    print(m)
        #    raise Exception('something went wrong...')
    
    #except:
    #    print(p_str)
        
# store the centroid locations of all patients next to the patches
with open(patch_path.joinpath('locations.pkl'), 'wb') as f:
    pickle.dump(physical_locations, f)
    # redundant: the with-block already closes the file
    f.close()

In [None]:
# Express the tumour locations of every patient relative to a reference tumour:
# 'GTVp' when present, otherwise the tumour with the largest third (z) coordinate.
patient_patch_paths = patch_path.glob('*/')
tumor_locations = pd.read_pickle(location_pickle_path)
centered_locations = {}  # patient -> tumour -> location minus the reference location
no_gtvp = []             # patients without a 'GTVp' entry
for pat in tqdm(patient_patch_paths):
    pat_str = pat.as_posix().split('/')[-1]
    # Skip entries of the patch folder that are not patient folders.
    if 'locations' in pat_str: continue
    if 'no_gtvp' in pat_str: continue
    print(pat_str)
    centered_locations[pat_str] = {}
    patient_tumors = tumor_locations[pat_str]
    n_tumors = len(patient_tumors)

    if n_tumors == 1:
        # A single tumour is its own reference and therefore sits at the origin.
        if 'GTVp' in patient_tumors:
            centered_locations[pat_str]['GTVp'] = np.array([0., 0., 0.])
        else:
            only_tumor = next(iter(patient_tumors.keys()))
            # Same key clean-up as in the multi-tumour case below.
            centered_locations[pat_str][only_tumor.replace('.nii.gz','')] = np.array([0., 0., 0.])
            no_gtvp.append(pat_str)
        continue

    print(f"    {patient_tumors.keys()}")
    if 'GTVp' in patient_tumors:
        translation_factor = patient_tumors['GTVp']
    else:
        no_gtvp.append(pat_str)
        print('    no GTVp, choosing highest GTVn in Z')
        array_locs = np.array([val for val in patient_tumors.values()])
        # Pick the row with the largest z coordinate. Comparing the whole array against the
        # maximum z value could also match a row whose x or y happens to equal that value.
        origin_idx = int(np.argmax(array_locs[:, 2]))
        translation_factor = array_locs[origin_idx]

    for tumor in patient_tumors:
        centered_locations[pat_str][tumor.replace('.nii.gz','')] = patient_tumors[tumor] - translation_factor

#with open(edge_path.joinpath('centered_locations_radcure_100324.pkl'), 'wb') as f:
#    pickle.dump(centered_locations, f)
#    f.close()

In [None]:
# Build one directed graph per patient from the centred tumour locations.
patient_graphs = {}
edge_dict = {}
for pat, locations in centered_locations.items():
    graph = nx.DiGraph(directed=True)
    patient_graphs[pat] = graph
    nodes = list(locations.keys())
    node_pos = list(locations.values())
    n_nodes = len(nodes)

    if n_nodes < 2:
        # A single node only gets a self-loop.
        edges_for_nx = [(nodes[0], nodes[0])]
    else:
        n_neighbors = 3 if n_nodes > 3 else n_nodes - 1
        edge_list = sklearn.neighbors.kneighbors_graph(node_pos, n_neighbors).toarray()
        # NOTE(review): the adjacency values in `edge_list` are not used as a filter, so every
        # node is connected to every node (itself included) — presumably intended, given that
        # the graphs are saved as "complete" graphs; confirm.
        edges_for_nx = [
            (source, nodes[jdx])
            for adjacency_row, source in zip(edge_list, nodes)
            for jdx in range(len(adjacency_row))
        ]

    graph.add_edges_from(edges_for_nx)

        

In [None]:
# Inspect the edges of one patient graph (fails with a KeyError if this patient id is absent).
patient_graphs['RADCURE-0006'].edges

In [None]:
# Save the patient graphs. The `with` block closes the file, so no explicit close() is needed.
# NOTE(review): `edge_path` is not defined in the visible part of this notebook — confirm.
with open(edge_path.joinpath('proto_complete_graphs_100424.pkl'), 'wb') as f:
    pickle.dump(patient_graphs, f)

In [None]:
# Collect one image/mask file pair per mask; the key is '<patient>__<mask name>'.
rad_dict = {}
for pat in tqdm(list(nii_path.glob('*'))):
    pat_str = pat.name
    for m in pat.glob('*.nii.gz'):
        # Skip the image files; only the file name is checked, not the whole path.
        if 'image' in m.name:
            continue
        # Remove the extension and the 'Struct_' prefix with replace(). str.strip() takes a
        # set of characters, not a suffix/prefix, and would also eat trailing letters of the
        # mask name (for example the final 'n' of 'GTVn').
        m_str = m.name.replace('.nii.gz', '').replace('Struct_', '')
        key_name = f"{pat_str}__{m_str}"
        rad_dict[key_name] = {
            # The image lies next to the mask; only the file name is swapped.
            'Image': m.with_name('image.nii.gz').as_posix(),
            'Mask': m.as_posix(),
        }

In [None]:
# One row per rad_dict key, with the inner dictionary keys ('Image', 'Mask') as columns.
rad_df = pd.DataFrame.from_dict(rad_dict, orient='index')

In [None]:
# Print the multiples of 1000 from 0 up to 15000.
print(list(range(0,16000, 1000)))

In [None]:
# Display rows 15000 to 15999 of the image/mask table.
rad_df.iloc[15000:16000]

In [None]:
# Display the data root path.
data_path

In [None]:
# Run the pyradiomics command line tool on the image/mask table, one batch of rows at a time.
RADIOMICS_BATCH_SIZE = 1000
radiomics_input_csv = data_path.joinpath('proto_radiomics.csv')

# Only the batch starting at row 8000 is processed with the current range.
for idx in range(8000, 9000, RADIOMICS_BATCH_SIZE):
    # The input file is overwritten for every batch; the output file is named per batch.
    rad_df.iloc[idx:idx+RADIOMICS_BATCH_SIZE].to_csv(radiomics_input_csv)
    command = [
        "pyradiomics",
        radiomics_input_csv.as_posix(),
        "-o", data_path.joinpath(f"radiomics_part_{idx}.csv").as_posix(),
        "-f", "csv",
        "--param", './hnc_project/radiomics/pyradiomics_param.yaml',
    ]
    # The command is passed as a list, so no shell is involved.
    result = subprocess.run(command)
    if result.returncode != 0:
        # Report a failed batch instead of ignoring it silently; later batches still run.
        print(f"pyradiomics exited with code {result.returncode} for the batch starting at row {idx}")

In [None]:
# Display the index of the image/mask table.
rad_df.index

In [None]:
# Create a radiomics feature extractor, enable the 'Wavelet' image type and print its setup.
# NOTE(review): `radiomics` and `featureextractor` are only imported in commented-out lines of
# the import cell; this cell presumably needs those imports to be enabled — confirm.
radiomics.setVerbosity(20)
extractor = featureextractor.RadiomicsFeatureExtractor()
extractor.enableImageTypeByName('Wavelet')
print(extractor.settings)
print(extractor.enabledImagetypes)
print(extractor.enabledFeatures)

In [None]:
# Extract radiomics features for every image patch of every patient and store, per patient,
# the features whose name starts with 'original'.
# NOTE(review): `patch_path`, `radiomics_path` and `extractor` come from other cells; `patch_path`
# and `radiomics_path` are not defined in the visible part of this notebook — confirm.
patient_patch_paths = patch_path.glob('*/')
for pat in patient_patch_paths:
    pat_str = pat.as_posix().split('/')[-1]
    print(pat_str)

    features_to_keep = {}
    for p in pat.glob('image*.nii.gz'):
        # Patch name = text after the last '_' of the path, without the extension.
        p_name = p.as_posix().split('_')[-1].replace('.nii.gz','')
        print(f"    {p_name}")
        image = p.as_posix()
        # The mask has the same file name with the leading 'image' replaced by 'Struct'.
        # Only the file name is rewritten, so an 'image' substring in a parent folder name
        # cannot corrupt the mask path.
        mask = p.with_name(p.name.replace('image', 'Struct', 1)).as_posix()
        features = extractor.execute(image, mask)
        features_to_keep[p_name] = {key: value for key, value in features.items() if key.startswith('original')}

    # The `with` block closes the file; no explicit close() is needed.
    with open(radiomics_path.joinpath(f"features_{pat_str}.pkl"), 'wb') as f:
        pickle.dump(features_to_keep, f)
      
 