## Scan Augmentation

This script handles the augmentation of CT scans and segmentation masks. As a prerequisite, the user must provide the ground truth CT scan files as well as the corresponding ground truth segmentation masks.

Please note that the functions find_pairs() AND augment_nifti() NEED to be edited BEFORE RUNNING THIS SCRIPT to match the naming convention of your NIfTI files.

In [None]:
import torchio as tio
import os
from tqdm import tqdm
import tkinter as tk
from tkinter import filedialog
import numpy as np
import pandas as pd
import datetime
import logging
import random
import torch
import sys
import traceback

Set global variables

In [None]:
# Module-level accumulators that augment_nifti() appends to during the run.
image_records, list_of_ids, inversion_list = [], [], []
# Run timestamp (YYYYmmdd_HHMMSS) embedded in the names of the .xlsx and .pth output files.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

#### pick_files():
The following function opens a series of file dialogs for the user to select directories and returns the selected paths.

The function prompts the user to select the following directories:
1. The folder containing the CT scan niftis.
2. The folder containing the segmentation masks of the CT-scan niftis.
3. The output directory where processed files will be saved.
4. The directory for storing the logging file, the augmentation metadata .xlsx file and the inversion metadata .pth file.

Returns:
    tuple: A tuple containing four paths (CT_dir_path, SM_dir_path, output_path, logging_output_path).

In [None]:
def pick_files():
    """Ask the user for the four directories this script works with.

    One directory dialog is shown per path and each selection is echoed
    to stdout.

    Returns:
        tuple: (CT_dir_path, SM_dir_path, output_path, logging_output_path).

    Raises:
        RuntimeError: if a dialog is cancelled. Continuing with an empty path
            would be dangerous: an empty output_path makes augment_nifti()
            build output names like "/CT_...".
    """
    root = tk.Tk()
    root.withdraw()  # Hide the main tkinter window
    try:
        CT_dir_path = _ask_directory("Select the folder with the CT SCAN NIFTIS")
        SM_dir_path = _ask_directory("Select the folder with the SEGMENTATION MASKS of THE CT-SCAN NIFTIS")
        output_path = _ask_directory("Select the OUTPUT DIRECTORY")
        logging_output_path = _ask_directory("Select the OUTPUT DIRECTORY for the .LOG, .XLSX and .PTH FILES")
    finally:
        # Destroy the hidden root window so re-running the cell does not leak Tk instances.
        root.destroy()
    return CT_dir_path, SM_dir_path, output_path, logging_output_path


def _ask_directory(title):
    """Show one directory dialog, echo the selection and reject a cancelled dialog."""
    path = filedialog.askdirectory(title=title)
    print(f"Selected folder: {path}")
    # askdirectory returns an empty value when the user presses Cancel.
    if not path:
        raise RuntimeError(f"No directory selected for: {title}")
    return path

#### get_nifti_file_paths():
Retrieves the file paths of NIfTI files in the specified directory.

This function searches for files with '.nii' or '.nii.gz' extensions
in the given directory and returns their full paths.

Args:
    directory_path (str): The path to the directory containing NIfTI files.

Returns:
    list of str: A list of full file paths for each NIfTI file found in the directory.

In [None]:
def get_nifti_file_paths(directory_path):
    """Return the full paths of all '.nii' / '.nii.gz' files in *directory_path*.

    Only the top level of the directory is listed (no recursion), and the
    order is whatever os.listdir() yields.
    """
    nifti_suffixes = ('.nii', '.nii.gz')
    found = []
    for entry in os.listdir(directory_path):
        if entry.endswith(nifti_suffixes):
            found.append(os.path.join(directory_path, entry))
    return found


#### adjust_hyperparams_for_lvl2():
Adjusts hyperparameters for RandomElasticDeformation (Level 2 transformation) based on the image size and spacing to avoid potential folding during elastic deformation.

Args:
    image (torchio.Image): A torchio Image instance.
    num_control_points (int or tuple or list): The number of control points for the deformation grid.

Returns:
    list of float: A list of three floats representing the maximum displacement for each dimension (x, y, z); the z component is always 0.

In [None]:
def adjust_hyperparams_for_lvl2(image, num_control_points):
    """Compute a max_displacement for RandomElasticDeformation that avoids folding.

    The physical extent of the image (size in voxels * spacing) is divided
    into control-point cells as ``bounds / (num_control_points - 2)``. Half
    of that cell spacing is treated as the folding limit, and half of the
    limit (a quarter of the cell spacing) is returned as the displacement.
    (Presumably this mirrors TorchIO's own folding check; confirm against the
    installed version.)

    Args:
        image (torchio.Image): image whose size and spacing define the grid.
        num_control_points (int or tuple or list or np.ndarray): number of
            control points per axis; a single int applies to every axis.

    Returns:
        list: [max_x, max_y, 0], the maximum displacement per axis. The z
        component is always 0.
    """
    sitk_image = image.as_sitk()
    # Physical extent of the volume along each axis.
    bounds = np.array(sitk_image.GetSize()) * np.array(sitk_image.GetSpacing())
    # np.asarray makes the documented int/tuple/list inputs work; "tuple - 2" would raise TypeError.
    grid_spacing = bounds / (np.asarray(num_control_points) - 2)
    potential_folding = grid_spacing / 2
    # NOTE(review): the third axis is never displaced; presumably intentional
    # (e.g. anisotropic CT slices) - confirm.
    return [potential_folding[0] / 2, potential_folding[1] / 2, 0]

#### augment_nifti():

Applies a transformation to a CT and a corresponding SM and saves the transformed files. Details of the transformation are noted in the paper.

The transformation to be applied is determined by the 'lvl' parameter.
Level 1: RandomAffine with scales=(1), degrees=45 (i.e. only rotation)
Level 2: RandomElasticDeformation with num_control_points=(50,50,70) and max_displacement as calculated by adjust_hyperparams_for_lvl2
Level 3: Composition of Level 2 and Level 1

Args:
    CT_path (str): The file path of the CT scan NIfTI file.
    SM_path (str): The file path of the corresponding segmentation mask NIfTI file.
    output_path (str): The directory where the transformed files will be saved.
    lvl (str): The level of transformation to be applied ('1', '2' or '3').

Returns:
    None

In [None]:
def _elastic_deformation(image, num_control_points=(50, 50, 70)):
    """Build the level-2 RandomElasticDeformation with a displacement sized for *image*."""
    max_displacement = adjust_hyperparams_for_lvl2(image, np.array(num_control_points))
    return tio.RandomElasticDeformation(
        num_control_points=num_control_points,
        max_displacement=max_displacement,
        locked_borders=2,
        image_interpolation='nearest',  # TODO: should be bspline
    )


def _rotation_only_affine():
    """Build the level-1 RandomAffine: random rotation of up to 45 degrees, no scaling."""
    # TorchIO interprets a single scales value x as the range U(1 - x, 1 + x), so the
    # former scales=(1) (which is just the int 1) sampled scale factors in [0, 2].
    # The explicit range (1, 1) pins every scale factor to 1, i.e. "only rotation".
    return tio.RandomAffine(scales=(1, 1), degrees=45)


def _build_transformation(image, lvl):
    """Return the TorchIO transform for augmentation level '1', '2' or '3'."""
    if lvl == "1":
        return _rotation_only_affine()
    if lvl == "2":
        return _elastic_deformation(image)
    if lvl == "3":
        # Level 3 = level 2 followed by level 1.
        return tio.Compose([_elastic_deformation(image), _rotation_only_affine()])
    raise ValueError(f"Unknown augmentation level {lvl!r}; expected '1', '2' or '3'")


def augment_nifti(CT_path, SM_path, output_path, lvl):
    """Randomly transform a CT scan and apply the identical transform to its mask.

    The transform is chosen by *lvl*:
      '1' - RandomAffine, rotation up to 45 degrees, scale fixed at 1
      '2' - RandomElasticDeformation, num_control_points=(50, 50, 70), with
            max_displacement from adjust_hyperparams_for_lvl2()
      '3' - level 2 followed by level 1

    Both results are written to *output_path* as
    ``CT_<name>_lvl_<lvl>_AUGMENTED.nii.gz`` and
    ``SM_<name>_lvl_<lvl>_AUGMENTED.nii.gz``, made unique via
    get_unique_file_path(). The function also appends the CT name to the
    module-level ``list_of_ids``, a metadata dict to ``image_records``, and
    ``(SM output path, inverse transform)`` to ``inversion_list``.

    Args:
        CT_path (str): path of the CT scan NIfTI file.
        SM_path (str): path of the corresponding segmentation mask NIfTI file.
        output_path (str): directory that receives the transformed files.
        lvl (str or int): augmentation level, 1, 2 or 3.

    Returns:
        None

    Raises:
        ValueError: if *lvl* is not 1, 2 or 3.
    """
    # The documented form is a string; also accept ints such as random.choice([1, 2, 3]).
    lvl = str(lvl)

    #!---------------- CHANGE NAME BASED ON GIVEN NAMING CONVENTION OF NIFTI FILES -----------------!#
    CT_file_name = os.path.splitext(os.path.basename(CT_path))[0].split('.nii')[0]
    SM_file_name = os.path.splitext(os.path.basename(SM_path))[0].split('_combined')[0]

    CT_subject = tio.Subject(image=tio.ScalarImage(CT_path))
    SM_subject = tio.Subject(image=tio.LabelMap(SM_path))

    transformation = _build_transformation(CT_subject['image'], lvl)

    # Apply the random transformation to the CT ...
    transformed_CT_subject = transformation(CT_subject)
    # ... then replay the recorded history on the mask so CT and SM are transformed identically.
    SM_transform = transformed_CT_subject.get_composed_history()
    transformed_SM_subject = SM_transform(SM_subject)

    # Kept so the augmentation can be undone later.
    inverse_transform = transformed_SM_subject.get_inverse_transform()

    # Save the transformed images under unique names.
    CT_file_output_path = get_unique_file_path(
        os.path.join(output_path, f"CT_{CT_file_name}_lvl_{lvl}_AUGMENTED.nii.gz")
    )
    SM_file_output_path = get_unique_file_path(
        os.path.join(output_path, f"SM_{SM_file_name}_lvl_{lvl}_AUGMENTED.nii.gz")
    )
    transformed_CT_subject['image'].save(CT_file_output_path)
    transformed_SM_subject['image'].save(SM_file_output_path)

    # Record metadata for the transformed image pair.
    list_of_ids.append(CT_file_name)
    image_records.append({
        'Input CT Path': CT_path,
        'Input SM Path': SM_path,
        'Level': lvl,
        'Transformation Applied': SM_transform,
        'Output CT Path': CT_file_output_path,
        'Output SM Path': SM_file_output_path,
    })

    # Save the inverse transformation information.
    inversion_list.append((SM_file_output_path, inverse_transform))

#### get_unique_file_path():

Returns a unique file path if the given file path already exists.

If the given file path does not exist, it is returned as is.
If the file exists, a counter is added to the file name until a unique file is found.

Parameters:

    file_path : str
        The file path to check

Returns:

    str
        The unique file path

In [None]:
def get_unique_file_path(file_path):
    """Return *file_path*, or a numbered variant of it if it already exists.

    A counter is inserted before the extension (``name.nii.gz`` ->
    ``name_(1).nii.gz``, ``name_(2).nii.gz``, ...) until a path that does
    not exist is found. ``.nii.gz`` is treated as a single extension; any
    other path keeps the extension reported by os.path.splitext.

    Args:
        file_path (str): candidate output path.

    Returns:
        str: a path that did not exist at the time of the check.
    """
    if not os.path.exists(file_path):
        return file_path

    # os.path.splitext only strips ".gz" from "x.nii.gz", which used to yield
    # "x.nii_(1).nii.gz"; split off the double extension as a whole instead.
    nii_gz = '.nii.gz'
    if file_path.endswith(nii_gz):
        base, ext = file_path[:-len(nii_gz)], nii_gz
    else:
        base, ext = os.path.splitext(file_path)

    counter = 1
    candidate = f"{base}_({counter}){ext}"
    while os.path.exists(candidate):
        counter += 1
        candidate = f"{base}_({counter}){ext}"
    return candidate

#### save_records_to_excel():

Saves the provided records to an Excel file at the specified output path.

This function converts the given records into a pandas DataFrame and
saves it as an Excel file named 'augmentation_records_<timestamp>.xlsx'
in the specified output directory.

Args:

    records (list of dict): The records to be saved, where each record
                            is a dictionary with the same keys.
    output_excel_path (str): The directory path where the Excel file
                             will be saved.

Returns:

    None

In [None]:
def save_records_to_excel(records, output_excel_path):
    """Write *records* to ``augmentation_records_<timestamp>.xlsx`` in *output_excel_path*.

    Each dict in *records* becomes one row and its keys become the column
    headers. The module-level ``timestamp`` goes into the file name, and the
    full path of the written file is printed.
    """
    frame = pd.DataFrame(records)
    excel_file = os.path.join(output_excel_path, f'augmentation_records_{timestamp}.xlsx')
    frame.to_excel(excel_file, index=False)
    print(f"Records saved to {excel_file}")

#### find_pairs():
Finds matching pairs of paths between two lists based on file_names.

Given two lists of paths, this function matches the paths by their file names (ignoring the directory, the file extension and, for segmentation masks, the '_combined' suffix).
!----- THIS NEEDS TO BE EDITED BASED ON THE GIVEN NAMING CONVENTION OF THE NIFTI FILES. ----------!

Args:
    CT_paths (list of str): List of paths to CT scan NIfTI files.
    SM_paths (list of str): List of paths to segmentation mask NIfTI files.

Returns:
    list of tuple: A list of tuples, where each tuple contains a pair of
                   matched paths. The first element of each tuple is a
                   path from CT_paths and the second element is a path
                   from SM_paths.

In [None]:
def find_pairs(CT_paths, SM_paths):
    """Pair every CT scan with the segmentation mask that has the same name key.

    The key of a CT file is its base name without '.nii' / '.nii.gz'; the key
    of a mask is its base name cut at '_combined'. Files whose key has no
    counterpart are skipped. If several files map to the same key, the last
    one listed wins.

    Args:
        CT_paths (list of str): paths to CT scan NIfTI files.
        SM_paths (list of str): paths to segmentation mask NIfTI files.

    Returns:
        list of tuple: (CT path, SM path) pairs in the order of CT_paths.
    """
    #!---------------- CHANGE NAME BASED ON GIVEN NAMING CONVENTION OF NIFTI FILES -----------------!#
    CT_by_name = {
        os.path.splitext(os.path.basename(CT_path))[0].split('.nii')[0]: CT_path
        for CT_path in CT_paths
    }
    SM_by_name = {
        os.path.splitext(os.path.basename(SM_path))[0].split('_combined')[0]: SM_path
        for SM_path in SM_paths
    }

    # Exact key lookup. The former substring test ("case1" in "case10") paired a CT
    # with every mask whose key merely contained it, then indexed the masks by the CT
    # key, which duplicated pairs or raised KeyError.
    return [
        (CT_path, SM_by_name[name])
        for name, CT_path in CT_by_name.items()
        if name in SM_by_name
    ]

### Running the different methods

In [None]:
# Pick the directories and configure error logging.
CT_dir_path, SM_dir_path, output_path, logging_output_path = pick_files()
# The directory chosen for the .log file also receives the .xlsx records and the .pth inversion list.
excel_output_path = logging_output_path
logging_output_path = os.path.join(logging_output_path, 'augment_nifti_error.log')
logging.basicConfig(
    filename=logging_output_path,
    level=logging.ERROR,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,  # otherwise re-running this cell keeps logging to the previous file
)

# Get and match paths
CT_paths = get_nifti_file_paths(CT_dir_path)
SM_paths = get_nifti_file_paths(SM_dir_path)
CT_SM_pair = find_pairs(CT_paths, SM_paths)

# For each pair of paths, transform the NIfTI files 10 times, each time choosing randomly between lvl 1, 2 or 3.
# tqdm takes its total from len(CT_SM_pair); the former total=len(CT_paths) was wrong
# whenever a CT scan had no matching mask.
for CT_path, SM_path in tqdm(CT_SM_pair, desc="Transforming nifti pairs"):
    for _ in range(10):
        lvl_transform = random.choice([1, 2, 3])
        try:
            augment_nifti(CT_path, SM_path, output_path, str(lvl_transform))
        except Exception as e:
            # Best effort: log the failure and move on to the next augmentation of this pair.
            logging.error(f"Error transforming {CT_path} and {SM_path} for level {lvl_transform} in line {sys.exc_info()[2].tb_lineno}: {e}, continuing with next augmentation...")
            logging.error(traceback.format_exc())

try:
    # Save accumulated transformation records to excel
    save_records_to_excel(image_records, excel_output_path)
except Exception as e:
    logging.error(f"Error saving records to Excel in line {sys.exc_info()[2].tb_lineno}: {e}")
    logging.error(traceback.format_exc())

try:
    # Save the accumulated inversion metadata for the later inversion step. It goes to the
    # directory selected for the .LOG, .XLSX and .PTH files, as that dialog promises.
    torch.save(inversion_list, os.path.join(excel_output_path, f'inversion_list_{timestamp}.pth'))
except Exception as e:
    logging.error(f"Error saving inversion list in line {sys.exc_info()[2].tb_lineno}: {e}")
    logging.error(traceback.format_exc())