In [None]:
import os
from platipy.imaging import ImageVisualiser
import SimpleITK as sitk
import matplotlib.pyplot as plt
import skimage
import numpy as np
import nibabel as nib
import shutil
import itk
import json 
from scipy import ndimage
import feret
import csv
from PIL import Image, ImageEnhance, ImageOps
import random


In [None]:
# Root folder for the intermediate and final files written by this notebook.
# All folder paths below end with '/', because the code joins paths with `+`.
output_dir = 'data/'

# Binary tumor masks written by extract_tumor().
extracted_tumors = output_dir + 'extracted_tumor/'

# Cropped masks written by crop().
cropped_tumors = output_dir + 'cropped_tumors/'

# Volumes written by linear_interpolation().
interpolated_tumors = output_dir + 'interpolated_tumors/'

# Folders that extract_tumor() scans for input .nii.gz files.
input_dir_pre = ['data/pre/']

In [None]:
class Helper:
    """Stateless file-system and image helpers shared by the notebook cells."""

    @staticmethod
    def make_dir(directory_path, debug = False):
        """Create ``directory_path`` (including parents) unless it already exists.

        When ``debug`` is true, print whether the directory was created or found.
        """
        if os.path.exists(directory_path):
            if debug:
                print(f"Directory '{directory_path}' already exists.")
            return
        os.makedirs(directory_path)
        if debug:
            print(f"Directory '{directory_path}' created.")

    @staticmethod
    def is_nifti_file(file):
        """Return True when ``file`` ends with the ``.nii.gz`` suffix."""
        has_suffix = file.endswith(".nii.gz")
        return has_suffix

    @staticmethod
    def extract_label(image, label_value):
        """Return an array holding 1 where ``image == label_value`` and 0 elsewhere."""
        matches = image == label_value
        return np.where(matches, 1, 0)

    @staticmethod
    def get_base_name_from_nii_gz(file_path):
        """Strip the last two extensions (e.g. ``.nii.gz``) from ``file_path``."""
        without_last_ext, _ = os.path.splitext(file_path)
        base_name, _ = os.path.splitext(without_last_ext)
        return base_name

    @staticmethod
    def save_new_nifti(sitk_image, reference_image, output_file):
        """Copy spacing, origin and direction from ``reference_image`` and write the image.

        ``sitk_image`` is modified in place before being written to ``output_file``.
        """
        spacing = reference_image.GetSpacing()
        origin = reference_image.GetOrigin()
        direction = reference_image.GetDirection()

        sitk_image.SetSpacing(spacing)
        sitk_image.SetOrigin(origin)
        sitk_image.SetDirection(direction)

        sitk.WriteImage(sitk_image, output_file)

# Extract tumor from liver segmentations

In [None]:
def extract_tumor(tumor_label, input_dirs=None, extracted_dir=None):
    """Write a binary mask of ``tumor_label`` for every NIfTI file in the input folders.

    Parameters
    ----------
    tumor_label : int
        Label value to keep; matching voxels become 1, all others 0.
    input_dirs : iterable of str, optional
        Folders to scan for ``.nii.gz`` files. Defaults to the module-level
        ``input_dir_pre``.
    extracted_dir : str, optional
        Folder the masks are written to (created if missing). Defaults to the
        module-level ``extracted_tumors``.

    Notes
    -----
    The mask is written with the first array axis moved to the last position,
    so the array SimpleITK reads back is indexed ``[row, col, slice]``; ``crop()``
    unpacks the shape in that order.
    NOTE(review): spacing, origin and direction are copied from the input image
    unchanged even though the axes were reordered — confirm this is intended.
    """
    source_dirs = input_dir_pre if input_dirs is None else input_dirs
    target_dir = extracted_tumors if extracted_dir is None else extracted_dir

    Helper.make_dir(target_dir)
    for input_dir in source_dirs:
        for image in os.listdir(input_dir):
            if not Helper.is_nifti_file(image):
                continue

            image_obj = sitk.ReadImage(os.path.join(input_dir, image))
            image_array = sitk.GetArrayFromImage(image_obj)

            # Label the whole volume at once, then move the first axis last.
            # This equals stacking the per-slice masks along axis=-1.
            label_mask = Helper.extract_label(image_array, tumor_label)
            label_data = np.ascontiguousarray(np.moveaxis(label_mask, 0, -1))

            sitk_image = sitk.GetImageFromArray(label_data)
            Helper.save_new_nifti(sitk_image, image_obj, os.path.join(target_dir, image))


# Find Maximal diameters


In [None]:
def max_axial_diam(img_path):
    """Return the largest per-slice maximum Feret diameter and the slice it occurs in.

    The volume is read as int64 and iterated along the first axis of the array
    returned by SimpleITK. Slices containing only zeros are skipped.

    Returns
    -------
    tuple
        ``(diameter, slice_index)``. When every slice is empty the result is
        ``(0, -1)``.

    NOTE(review): the value is whatever ``feret.max`` returns; no voxel spacing is
    applied here, so it is presumably in pixels — confirm against the thresholds
    used by the callers.
    """
    volume = sitk.GetArrayFromImage(
        sitk.ReadImage(img_path, imageIO="NiftiImageIO", outputPixelType = sitk.sitkInt64)
    )

    best_diameter = -1
    best_index = -1
    for index, plane in enumerate(volume):
        # No non-zero voxel in this slice: nothing to measure.
        if not plane.any():
            continue
        diameter = feret.max(plane, edge = True)
        if diameter > best_diameter:
            best_diameter, best_index = diameter, index

    if best_diameter == -1:
        best_diameter = 0
    return best_diameter, best_index
    
def find_max_axial_diams(dir):
    """Measure every ``.nii.gz`` file in ``dir`` with ``max_axial_diam``.

    Parameters
    ----------
    dir : str
        Folder to scan; a trailing path separator is optional.

    Returns
    -------
    dict
        Maps each file name to ``[max_diameter, slice_index]``. Entries are
        inserted in sorted file-name order so the result does not depend on the
        order in which the file system lists the folder.
    """
    diameters = {}
    for patient in sorted(os.listdir(dir)):
        if not Helper.is_nifti_file(patient):
            continue
        max_diam, max_slice = max_axial_diam(os.path.join(dir, patient))
        diameters[patient] = [max_diam, max_slice]
    return diameters

# Centre and crop

In [None]:
def find_crop_dim():
    """Return the side length of the square crop window shared by all patients.

    Every ``.nii.gz`` mask in ``extracted_tumors`` is collapsed along array axis 2
    (as loaded by nibabel). The extent of the non-zero region is then measured
    along the two remaining axes, counting both end voxels. The largest extent
    over all files is returned, or 0 when no mask contains a non-zero voxel.

    NOTE(review): nibabel and SimpleITK index array axes in opposite order, and
    ``crop()`` reads these same files through SimpleITK — confirm that axis 2
    here is the slice axis ``crop()`` iterates over.
    """
    largest_extent = 0
    for patient in os.listdir(extracted_tumors):
        if not Helper.is_nifti_file(patient):
            continue

        image_array = nib.load(extracted_tumors + patient).get_fdata()
        # True wherever any position along axis 2 holds a non-zero voxel.
        footprint = np.any(image_array != 0, axis=2)

        for axis in (0, 1):
            # Positions along `axis` that contain at least one tumor voxel.
            occupied = np.nonzero(np.any(footprint, axis=1 - axis))[0]
            if occupied.size == 0:
                # Empty mask: there is no extent to measure for this file.
                break
            # +1 because both the first and the last occupied voxel are inside.
            extent = int(occupied[-1] - occupied[0]) + 1
            largest_extent = max(largest_extent, extent)

    # The crop is square, so a single maximum over both axes is returned.
    return largest_extent



In [None]:
def _clamp_crop_start(start, crop_dim, size):
    """Shift ``start`` so that ``[start, start + crop_dim)`` lies inside ``[0, size)``."""
    # When crop_dim exceeds size the window cannot fit; anchor it at 0.
    return max(0, min(start, size - crop_dim))


def crop(crop_dim):
    """Crop a ``crop_dim`` x ``crop_dim`` window around the tumor in every slice.

    For each ``.nii.gz`` mask in ``extracted_tumors`` the array read by SimpleITK
    is treated as ``[row, col, slice]``. Slices with no non-zero voxel are
    dropped. In every other slice the window is centred on the tumor's bounding
    box and shifted, where needed, to stay inside the image. The cropped slices
    are stacked in reverse order, given the spacing of the input file and
    written to ``cropped_tumors`` under the same file name. Files without any
    tumor slice are skipped.

    Because empty slices are dropped and the order is reversed, slice indices in
    the output do not match slice indices in the input.
    NOTE(review): the reason for reversing the slice order is not visible here —
    confirm it is intended.
    """
    Helper.make_dir(cropped_tumors)
    for patient in os.listdir(extracted_tumors):
        if not Helper.is_nifti_file(patient):
            continue

        image_path = extracted_tumors + patient
        img_obj = sitk.ReadImage(image_path, imageIO="NiftiImageIO", outputPixelType = sitk.sitkInt64)
        image_array = sitk.GetArrayFromImage(img_obj)
        nRows, nCols, nSlice = image_array.shape

        cropped = []
        for currSlice in range(nSlice):
            slice_mask = image_array[:, :, currSlice]
            if not slice_mask.any():
                continue

            # Bounding box of the tumor in this slice.
            tumor_cols = np.nonzero(slice_mask.any(axis=0))[0]
            tumor_rows = np.nonzero(slice_mask.any(axis=1))[0]
            ctr_x = tumor_cols[0] + (tumor_cols[-1] - tumor_cols[0]) // 2
            ctr_y = tumor_rows[0] + (tumor_rows[-1] - tumor_rows[0]) // 2

            # Indices are 0-based, so the smallest valid window start is 0.
            startRow = _clamp_crop_start(ctr_y - crop_dim // 2, crop_dim, nRows)
            startCol = _clamp_crop_start(ctr_x - crop_dim // 2, crop_dim, nCols)

            window = image_array[startRow : startRow + crop_dim, startCol : startCol + crop_dim, currSlice]
            cropped.append(window.astype(np.uint8))

        if not cropped:
            # No tumor voxel in the whole volume: nothing to write.
            continue

        cropped.reverse()
        sitk_image = sitk.GetImageFromArray(np.stack(cropped))
        sitk_image.SetSpacing(img_obj.GetSpacing())
        sitk.WriteImage(sitk_image, cropped_tumors + patient)



# Interpolation

In [None]:
def linear_interpolation(target_voxel_spacing=(1.0, 1.0, 1.0)):
    """Resample every cropped volume to ``target_voxel_spacing`` with linear interpolation.

    Parameters
    ----------
    target_voxel_spacing : sequence of float
        Desired spacing in SimpleITK order (x, y, z), one value per image axis.

    Each ``.nii.gz`` file in ``cropped_tumors`` is resampled with
    ``scipy.ndimage.zoom(order=1)`` and written to ``interpolated_tumors`` under
    the same name with its spacing set to the target. The per-axis zoom factor is
    ``current_spacing / target_spacing``. The output shape is rounded to whole
    voxels, so the achieved spacing can differ slightly from the target.
    """
    Helper.make_dir(interpolated_tumors)
    target_spacing = [float(value) for value in target_voxel_spacing]

    for patient in os.listdir(cropped_tumors):
        if not Helper.is_nifti_file(patient):
            continue

        image_obj = sitk.ReadImage(cropped_tumors + patient)
        image_array = sitk.GetArrayFromImage(image_obj)
        voxel_spacing = image_obj.GetSpacing()

        # SimpleITK reports spacing as (x, y, z) while the NumPy array is indexed
        # (z, y, x), so both sequences are reversed to line up with the array axes.
        zoom_factors = [
            current / target
            for current, target in zip(reversed(voxel_spacing), reversed(target_spacing))
        ]

        interpolated_scan = ndimage.zoom(image_array, zoom_factors, order=1)
        interpolated_image = sitk.GetImageFromArray(interpolated_scan)
        interpolated_image.SetSpacing(target_spacing)
        sitk.WriteImage(interpolated_image, interpolated_tumors + patient)
        

# Run steps

In [None]:
# Label value that extract_tumor() keeps as the tumor mask.
tumor_label = 2
extract_tumor(tumor_label)

In [None]:
# Maps each mask file name to [maximum diameter, index of the slice it was found in].
max_diams = find_max_axial_diams(extracted_tumors)

In [None]:
# Remove patients whose largest tumor diameter is below the inclusion threshold.
# The threshold is in the same unit as the diameters stored in max_diams.
MIN_AXIAL_DIAMETER = 15
for patient, values in max_diams.items():
    small_tumor_path = extracted_tumors + patient
    # max_diams still lists files deleted by an earlier run of this cell, so only
    # remove what is actually on disk; this keeps the cell safe to re-run.
    if values[0] < MIN_AXIAL_DIAMETER and os.path.exists(small_tumor_path):
        os.remove(small_tumor_path)


In [None]:
# Side length of the square crop window, derived from the masks left in extracted_tumors.
crop_dim = find_crop_dim()

In [None]:
# Write the cropped masks to cropped_tumors.
crop(crop_dim)

In [None]:
# Run the interpolation step with a target_voxel_spacing of [1, 1, 1];
# results are written to interpolated_tumors.
linear_interpolation([1,1,1])

In [None]:
# Save each slice of every interpolated volume as an individual PNG under
# data/final/<patient>/slice_<index>.png.
patients = os.listdir(interpolated_tumors)
final = output_dir + 'final/'
Helper.make_dir(final)
for patient in patients:
    if not Helper.is_nifti_file(patient): continue

    patient_dir = interpolated_tumors + patient
    image_obj = sitk.ReadImage(patient_dir)
    image_array = sitk.GetArrayFromImage(image_obj)
    # [:-2] drops the last two characters of the base name (e.g. a '_0' suffix).
    out = final + Helper.get_base_name_from_nii_gz(patient)[:-2] + '/'
    Helper.make_dir(out)
    # plt.imsave writes the array straight to disk, so no figure is created.
    # The loop variable is not called `slice`, to leave the builtin intact.
    for i, slice_pixels in enumerate(image_array):
        plt.imsave(out + 'slice_' + str(i) + '.png', slice_pixels, cmap='gray')

In [None]:
# Separate the maximum-diameter slices: copy one PNG per patient to data/max_slices/.
# NOTE(review): vals[1] is the slice index reported by max_axial_diam() on the
# extracted mask, whereas the PNGs in data/final/ are numbered after crop() dropped
# empty slices and reversed the order — confirm both indices refer to the same slice.
# NOTE(review): `dir` and `slice` shadow Python builtins for the rest of the session.
dir = 'data/final/'
out = 'data/max_slices/'
Helper.make_dir(out)
for patient, vals in max_diams.items():
    # Drop the extension and the last two characters (e.g. '_0') to get the folder name.
    patient_name = Helper.get_base_name_from_nii_gz(patient)[:-2]
    slice = vals[1]
    patient_dir = dir + patient_name
    image_name = 'slice_' + str(slice) + '.png'
    image_path = patient_dir + '/' + image_name
    dest = out + patient_name + '.png'
    # Patients without a matching PNG are skipped silently.
    if not os.path.exists(image_path): 
        continue
    shutil.copy(image_path, dest)


In [None]:
output_dir = 'data/'
# NOTE(review): this rebinds the module-level extracted_tumors that extract_tumor(),
# find_crop_dim() and crop() read; re-running those cells after this one makes them
# work on the post-treatment folder.
extracted_tumors = output_dir + 'extracted_tumor_post/'
# NOTE(review): no code visible here reads input_dirs; extract_tumor() uses
# input_dir_pre unless told otherwise — confirm which name was intended.
input_dirs = ['data/post/']
#extract_tumor(2)
# get maximum axial diameters in post-treatment scans for labeling
max_diams_post = find_max_axial_diams(extracted_tumors)

In [None]:
# Collect pre- and post-treatment maximum diameters per patient. Keys are the file
# name without its extensions and without its last two characters.
pre_diams = {}
post_diams = {}
for patient, vals in max_diams.items():
    # NOTE(review): one scan is excluded by name; the reason is not recorded here.
    if patient == 'RIA_17-010C_000_000188_0.nii.gz': continue
    diam = vals[0]
    # Same cut-off value as the cell that deletes small tumors from extracted_tumors.
    if diam < 15:
        continue
    patient_name = Helper.get_base_name_from_nii_gz(patient)[:-2]
    pre_diams[patient_name] = diam
    # Raises KeyError when max_diams_post has no '<patient_name>_1.nii.gz' entry.
    post_diams[patient_name] = max_diams_post[patient_name + '_1.nii.gz'][0]

In [None]:
def get_diam_change_total(pre_diams, post_diams):
    """Return the percentage change in diameter from pre to post for each patient.

    Patients whose pre-treatment diameter equals 0 are left out. A patient with a
    non-zero pre-treatment diameter but no entry in ``post_diams`` raises
    ``KeyError``.
    """
    return {
        patient: 100*((post_diams[patient] - pre_diam)/pre_diam)
        for patient, pre_diam in pre_diams.items()
        if not pre_diam == 0
    }
    
def get_recist_labels(diam_change):
    """Map each patient's percentage diameter change to a response label.

    Thresholds (in percent):
      * ``>= 20``                  -> 'Progressive Disease'
      * ``> -30`` and ``< 20``     -> 'Stable Disease'
      * exactly ``-100``           -> 'Complete Response'
      * any other value ``<= -30`` -> 'Partial Response'

    Raises
    ------
    ValueError
        If a change is NaN. Such a value satisfies none of the comparisons and
        would otherwise be labelled 'Complete Response'.
    """
    import math

    recist = {}
    for patient, percentage in diam_change.items():
        if math.isnan(percentage):
            raise ValueError(f"Diameter change for '{patient}' is NaN; cannot assign a label.")

        if percentage >= 20:
            recist[patient] = 'Progressive Disease'
        elif percentage > -30:
            recist[patient] ='Stable Disease'
        elif percentage == -100:
            recist[patient] = 'Complete Response'
        else:
            recist[patient] = 'Partial Response'

    return recist

In [None]:
# Percentage change in maximum diameter per patient, from pre_diams to post_diams.
diam_change = get_diam_change_total(pre_diams, post_diams)

In [None]:
# Response label per patient, derived from diam_change.
recist = get_recist_labels(diam_change)

In [None]:
# Write one space-delimited row per patient: the PNG file name followed by a
# one-hot label in the column order complete, partial, stable, progressive.
file_path = 'output.csv'
one_hot_by_label = {
    'Stable Disease': [0, 0, 1, 0],
    'Partial Response': [0, 1, 0, 0],
    'Progressive Disease': [0, 0, 0, 1],
    'Complete Response': [1, 0, 0, 0],
}

data = []
for patient, label in recist.items():
    if label in one_hot_by_label:
        content = [patient + '.png'] + one_hot_by_label[label]
    else:
        # Unknown labels produce an empty row.
        content = ''
    data.append(content)

with open(file_path, 'w', newline='') as csv_file:
    csv_writer = csv.writer(csv_file, delimiter=' ')
    csv_writer.writerows(data)


In [None]:
def flip_image(input_path, output_path, flip_horizontal=True, flip_vertical=False):
    """Save a mirrored copy of the image at ``input_path`` to ``output_path``.

    Parameters
    ----------
    input_path, output_path : str
        Source and destination image files.
    flip_horizontal : bool
        Mirror left-to-right when true.
    flip_vertical : bool
        Mirror top-to-bottom when true.
    """
    # The context manager closes the source file handle once the copy is saved.
    with Image.open(input_path) as source:
        flipped = source
        if flip_horizontal:
            flipped = flipped.transpose(Image.FLIP_LEFT_RIGHT)
        if flip_vertical:
            flipped = flipped.transpose(Image.FLIP_TOP_BOTTOM)
        flipped.save(output_path)
    
def zoom_image(input_path, output_path, zoom_factor):
    """Save a copy of the image resized by ``zoom_factor`` in both dimensions.

    The whole image is rescaled, so the saved file has different pixel dimensions
    from the source; nothing is cropped or padded.

    Raises
    ------
    ValueError
        If ``zoom_factor`` is not positive.
    """
    if zoom_factor <= 0:
        raise ValueError(f"zoom_factor must be positive, got {zoom_factor}")

    # The context manager closes the source file handle after resizing.
    with Image.open(input_path) as source:
        width, height = source.size
        # Keep at least one pixel per side; int() truncation could otherwise give 0.
        new_width = max(1, int(width * zoom_factor))
        new_height = max(1, int(height * zoom_factor))
        resized = source.resize((new_width, new_height))
    resized.save(output_path)

# Augment one class with a flipped and a resized copy of every image, and append
# the copies to output.csv with the same label as their source row.
#
# Row layout in output.csv: file name, then one-hot columns for complete, partial,
# stable and progressive response. The copies are meant to carry the progressive
# label (0 0 0 1), so rows are selected on that column.
# NOTE(review): the earlier version selected column 3 (stable) while labelling the
# copies as progressive — confirm that progressive is the class to augment.
# Re-running this cell without first regenerating output.csv appends duplicates.
AUGMENTED_LABEL_COLUMN = 4
MAX_SLICE_DIR = 'data/max_slices/'

new_content = []
with open('output.csv', 'r', newline='') as csv_file:
    reader = csv.reader(csv_file, delimiter=' ')
    for words in reader:
        # Skip blank or short rows and rows of other classes.
        if len(words) <= AUGMENTED_LABEL_COLUMN or words[AUGMENTED_LABEL_COLUMN] != '1':
            continue

        input_path = MAX_SLICE_DIR + words[0]
        if not os.path.exists(input_path):
            # No max-diameter slice was exported for this patient.
            continue

        patient = os.path.splitext(words[0])[0]
        source_label = words[1:]

        flip_horizontal = random.choice([True, False])
        flip_vertical = random.choice([True, False])
        if not flip_horizontal and not flip_vertical:
            # Always flip along at least one axis so the copy differs from the source.
            flip_horizontal = True
        output_path = MAX_SLICE_DIR + patient + '_flipped.png'
        flip_image(input_path, output_path, flip_horizontal, flip_vertical)
        new_content.append([patient + '_flipped.png'] + source_label)

        zoom_factor = random.uniform(0.5, 1.2)
        output_path = MAX_SLICE_DIR + patient + '_zoomed.png'
        zoom_image(input_path, output_path, zoom_factor)
        new_content.append([patient + '_zoomed.png'] + source_label)

# Append only after the read handle is closed.
if new_content:
    with open('output.csv', 'a', newline='') as csv_file:
        writer = csv.writer(csv_file, delimiter=' ')
        writer.writerows(new_content)

In [None]:
# Shuffle the rows of output.csv into data.csv.
# Both files are space-delimited. Reading with the same delimiter keeps each value
# in its own field, so the writer does not wrap whole rows in quotes.
SHUFFLE_SEED = 42  # fixed so the shuffle, and the split that follows, is repeatable

with open('output.csv', 'r', newline='') as csv_file:
    data = [row for row in csv.reader(csv_file, delimiter=' ') if row]

# A private generator leaves the global random state untouched.
random.Random(SHUFFLE_SEED).shuffle(data)

with open('data.csv', 'w', newline='') as csv_file:
    writer = csv.writer(csv_file, delimiter=' ')
    writer.writerows(data)


In [None]:
# Separate data.csv into train / test / validation lists (about 80 / 10 / 10).
csv_file_path = 'data.csv'
output_file_80 = 'train_list_crlm.txt'
output_file_10_1 = 'test_list_crlm.txt'
output_file_10_2 = 'val_list_crlm.txt'

# Read all non-empty rows once.
# data.csv is space-delimited and is parsed here with the default delimiter, so
# each row is a single field and '\t'.join(row) below writes the line unchanged.
# NOTE(review): if tab-separated output is wanted, pass delimiter=' ' to csv.reader.
with open(csv_file_path, 'r', newline='') as file:
    split_rows = [row for row in csv.reader(file) if row]

total_rows = len(split_rows)
rows_10 = int(0.1 * total_rows)
# Rows left over after rounding go to the training set instead of being dropped.
rows_80 = total_rows - 2 * rows_10

partitions = [
    (output_file_80, split_rows[:rows_80]),
    (output_file_10_1, split_rows[rows_80:rows_80 + rows_10]),
    (output_file_10_2, split_rows[rows_80 + rows_10:]),
]
assert sum(len(rows) for _, rows in partitions) == total_rows

for list_path, rows in partitions:
    with open(list_path, 'w') as list_file:
        list_file.writelines('\t'.join(row) + '\n' for row in rows)