# CAIRO5 data preparation

In [None]:
#imports
import nibabel as nib
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

from monai.transforms import (
    Resize,
    ScaleIntensityRange,
    Compose
)
import shutil

from scipy.ndimage import binary_dilation
from scipy.ndimage import generate_binary_structure

import pandas as pd


# Set numpy print options to avoid truncation
#np.set_printoptions(threshold=np.inf)

# Training (CAIRO5)

In [None]:
# Placeholder paths: fill in both before running this cell.
all_scans_orig_path = None
all_scans_path = None

# Walk the original scan folder and derive the destination name of every file
# by dropping each "_0000" occurrence from its name.
for scan in os.listdir(all_scans_orig_path):
    print("Currently processing: ", scan)
    dest_name = "".join(scan.split("_0000"))
    # The actual copy is currently disabled:
    #shutil.copy(os.path.join(all_scans_orig_path, scan), os.path.join(all_scans_path, dest_name))

### Define all useful paths

In [None]:
# All paths 
# Placeholders: every value must be set to a real location before the cells
# below are run.

# Folder with the scans; the filtering cells below delete files from it.
all_scans_path = None
# Folder with the segmentations; a segmentation is looked up under the same
# file name as its scan.
all_segmentations_path = None

# Folder whose file names are removed from all_scans_path.
teacher_data_path = None

# Output folder for the liver-masked, bounding-box-cropped scans.
paired_scans_path = None
# Output folder for the down-sized tumor masks (saved as .npy files).
paired_segmentations_path = None

# Clinical data file
# Read with pandas.read_excel; its "SubjectKey" column is used to filter scans.
clinical_data_path = None


### Delete scans that are in teacher

In [None]:
"""
Deletes scans from all_scans that are also in teacher_data
"""

# File names present in the teacher data folder
teacher_data_scans = set(os.listdir(teacher_data_path))
num_of_deleted = 0

# Remove every scan whose file name also occurs in the teacher data
for scan in os.listdir(all_scans_path):
    if scan not in teacher_data_scans:
        continue
    os.remove(os.path.join(all_scans_path, scan))
    print(f"Deleted scan: {scan}")
    num_of_deleted += 1

print(f"Number of deleted scans: {num_of_deleted}")

### Delete all scans where scan_type >= 2

In [None]:
# Delete every scan and segmentation whose type number is 2 or higher.
# File names are assumed to look like "<id>_<type>.nii.gz". The whole <type>
# token (everything between the first underscore and the first dot after it)
# is parsed, so multi-digit types such as "_10" are not mistaken for type 1,
# which is what looking at the first digit only would do.
scans_to_delete = [
    scan
    for scan in os.listdir(all_scans_path)
    if int(scan.split("_")[1].split(".")[0]) >= 2
]

for scan in scans_to_delete:
    print(f"Deleting scan: {scan}")
    os.remove(os.path.join(all_scans_path, scan))

# Same rule for the segmentations
segmentations_to_delete = [
    segm
    for segm in os.listdir(all_segmentations_path)
    if int(segm.split("_")[1].split(".")[0]) >= 2
]

for segm in segmentations_to_delete:
    print(f"Deleting segmentation: {segm}")
    os.remove(os.path.join(all_segmentations_path, segm))


### Delete scans without corresponding segmentation

In [None]:
# Delete every scan that has no segmentation with the same file name.
# The segmentation folder is listed once and kept as a set, instead of being
# re-read from disk for every single scan.
segmentation_names = set(os.listdir(all_segmentations_path))

scans_to_delete = [
    scan for scan in os.listdir(all_scans_path) if scan not in segmentation_names
]

for scan in scans_to_delete:
    print(f"Deleting scan: {scan}")
    os.remove(os.path.join(all_scans_path, scan))

### Subset 0 & 1 scans

In [None]:
# Keep only the ids for which both "<id>_0.nii.gz" and "<id>_1.nii.gz" exist;
# every other file in the scan folder is deleted.
scans_in_folder = set(os.listdir(all_scans_path))
scan_ids = {name.split("_")[0] for name in scans_in_folder}

scans_to_keep = set()
for scan_id in scan_ids:
    pair = {f"{scan_id}_0.nii.gz", f"{scan_id}_1.nii.gz"}
    if pair.issubset(scans_in_folder):
        scans_to_keep.update(pair)

counter = 0
for scan in scans_in_folder:
    if scan in scans_to_keep:
        continue
    print(f"Deleting scan: {scan}")
    counter += 1
    os.remove(os.path.join(all_scans_path, scan))

print(f"Number of scans deleted: {counter}")

### Delete scans of patients that are not in clinical file

In [None]:
# Subject keys listed in the clinical file
clinical_data = pd.read_excel(clinical_data_path)
subject_keys = clinical_data["SubjectKey"].astype(int).tolist()

# A scan is dropped when the last three characters of its id (the part before
# the first underscore), read as an integer, are not among the subject keys.
scans_to_delete = [
    scan
    for scan in os.listdir(all_scans_path)
    if int(scan.split("_")[0][-3:]) not in subject_keys
]

for scan in scans_to_delete:
    print(f"Deleting scan: {scan}")
    os.remove(os.path.join(all_scans_path, scan))
 


### Segment liver and bounding box

In [None]:
"""
Segments the liver and applies a bounding box 
"""

# For every scan in all_scans_path that has no output in paired_scans_path yet:
# set all voxels outside the liver mask to -1000, crop the volume to the mask's
# bounding box and save it under the same file name in paired_scans_path.

# List the output folder once. Every scan name is written at most once per run,
# so this snapshot stays valid for the whole loop.
already_processed = set(os.listdir(paired_scans_path))

for scan in os.listdir(all_scans_path):
    if scan in already_processed:
        print(f"skipping: {scan}, since it already exists")
        continue

    print(f"currently processing: {scan}")

    # Load the image and the segmentation stored under the same file name
    image = nib.load(os.path.join(all_scans_path, scan))
    segmentation = nib.load(os.path.join(all_segmentations_path, scan))

    image_data = image.get_fdata()
    segmentation_data = segmentation.get_fdata()

    # Segmentation labels 12 and 13 together form the liver mask
    liver_mask = (segmentation_data == 12) | (segmentation_data == 13)

    # Without a single liver voxel there is no bounding box: min()/max() over
    # the empty index array would raise and abort the run for all other scans.
    if not liver_mask.any():
        print(f"skipping: {scan}, since its segmentation contains no liver voxels")
        continue

    # Apply the mask: everything outside the liver becomes -1000
    liver_image = np.copy(image_data)
    liver_image[~liver_mask] = -1000

    # Bounding box of the liver mask (inclusive min/max index per axis)
    mask_indices = np.argwhere(liver_mask)
    min_indices = mask_indices.min(axis=0)
    max_indices = mask_indices.max(axis=0)

    # Crop the masked image to the bounding box
    cropped_liver_image = liver_image[
        min_indices[0]:max_indices[0] + 1,
        min_indices[1]:max_indices[1] + 1,
        min_indices[2]:max_indices[2] + 1,
    ]

    # NOTE(review): the affine and header of the uncropped image are reused, so
    # the crop offset is not reflected in the saved file's spatial metadata --
    # confirm this is acceptable for downstream use.
    new_image = nib.Nifti1Image(cropped_liver_image, affine=image.affine, header=image.header)

    # Save the new NIfTI image to a file with the original name
    output_file_path = os.path.join(paired_scans_path, scan)
    nib.save(new_image, output_file_path)

### Down-sized tumor mask data preparation

In [None]:
# 3D structuring element with connectivity 1 (a voxel plus its face neighbours)
structure = generate_binary_structure(3, 1)

# Outputs are stored as "<scan name without .nii.gz>.npy", so the
# already-exists check has to look for that name; the scan's own ".nii.gz"
# name never appears in the output folder.
existing_outputs = set(os.listdir(paired_segmentations_path))

for scan in os.listdir(paired_scans_path):
    output_name = scan[:-7] + ".npy"  # drops the 7-character ".nii.gz" suffix
    if output_name in existing_outputs:
        print(f"skipping: {scan}, since it already exists")
        continue

    print(f"Currently processing: {scan}")
    segmentation = nib.load(os.path.join(all_segmentations_path, scan))
    segmentation_data = segmentation.get_fdata()

    # Create a liver mask (labels 12 and 13)
    liver_mask = (segmentation_data == 12) | (segmentation_data == 13)

    # Without a single liver voxel there is no bounding box: min()/max() over
    # the empty index array would raise and abort the run for all other scans.
    if not liver_mask.any():
        print(f"skipping: {scan}, since its segmentation contains no liver voxels")
        continue

    # Bounding box of the liver mask (inclusive min/max index per axis)
    mask_indices = np.argwhere(liver_mask)
    min_indices = mask_indices.min(axis=0)
    max_indices = mask_indices.max(axis=0)

    # Crop the segmentation data using the bounding box
    cropped_segmentation_data = segmentation_data[
        min_indices[0]:max_indices[0] + 1,
        min_indices[1]:max_indices[1] + 1,
        min_indices[2]:max_indices[2] + 1,
    ]

    # Create a tumor mask (label 13) within the cropped liver region
    tumor_mask = (cropped_segmentation_data == 13)

    # Grow the tumor mask by 8 dilation steps, then add batch and channel
    # axes because F.interpolate expects a 5D input for trilinear mode
    dilated_tumor_mask = binary_dilation(tumor_mask, structure=structure, iterations=8)
    dilated_tumor_mask = torch.tensor(dilated_tumor_mask).unsqueeze(0).unsqueeze(0).float()

    # Resample to a fixed 128 x 128 x 32 grid; with trilinear interpolation the
    # result is a float array whose values may be fractional
    downsampled_tumor_mask = F.interpolate(dilated_tumor_mask, size=(128, 128, 32), mode="trilinear", align_corners=False)
    downsampled_tumor_mask = downsampled_tumor_mask.squeeze().numpy()

    # Save the down-sized mask as a NumPy file
    output_file_path = os.path.join(paired_segmentations_path, output_name)
    np.save(output_file_path, downsampled_tumor_mask)
    

In [None]:
# Report file names that exist in paired_scans_path but not in all_scans_path
paired_scans = set(os.listdir(paired_scans_path))
all_scans = set(os.listdir(all_scans_path))
scans_not_in_all_scans = paired_scans - all_scans

print("Scans in paired_scans but not in all_scans_path:")
for missing_scan in scans_not_in_all_scans:
    print(missing_scan)

# Testing (CAIRO5 subset)

In [None]:
# Useful paths
# Placeholders: every value must be set to a real location before the cells
# below are run.
 
# Working folders for the test scans and their segmentations; a segmentation
# is looked up under the same file name as its scan.
all_scans_path = None
all_segmentations_path = None

# Source folder: its "*_0000.nii.gz" files are copied to all_scans_path and
# its other ".nii.gz" files to all_segmentations_path.
test_data_path = None

# Output folders for the cropped liver scans and the down-sized tumor masks.
paired_scans_path = None
paired_segmentations_path = None

# NOTE(review): not referenced by any other cell visible in this file.
resized_paired_scans_path = None


### copy to correct folders

In [None]:
# Distribute the test files: "*_0000.nii.gz" files are scans (copied with the
# "_0000" part removed from the name), every other ".nii.gz" file is copied to
# the segmentation folder. Files with another extension are ignored.
for scan in os.listdir(test_data_path):
    if not scan.endswith(".nii.gz"):
        continue

    source = os.path.join(test_data_path, scan)
    if scan.endswith("_0000.nii.gz"):
        target = os.path.join(all_scans_path, scan.replace("_0000", ""))
    else:
        target = os.path.join(all_segmentations_path, scan)
    shutil.copy(source, target)

### Delete all scans and segmentations with type >= 2

In [None]:
# Delete every scan and segmentation whose type number is 2 or higher.
# File names are assumed to look like "<id>_<type>.nii.gz". The whole <type>
# token (everything between the first underscore and the first dot after it)
# is parsed, so multi-digit types such as "_10" are not mistaken for type 1,
# which is what looking at the first digit only would do.
scans_to_delete = [
    scan
    for scan in os.listdir(all_scans_path)
    if int(scan.split("_")[1].split(".")[0]) >= 2
]

for scan in scans_to_delete:
    print(f"Deleting scan: {scan}")
    os.remove(os.path.join(all_scans_path, scan))

# Same rule for the segmentations
segmentations_to_delete = [
    segm
    for segm in os.listdir(all_segmentations_path)
    if int(segm.split("_")[1].split(".")[0]) >= 2
]

for segm in segmentations_to_delete:
    print(f"Deleting segmentation: {segm}")
    os.remove(os.path.join(all_segmentations_path, segm))


### Delete scans without corresponding segmentation


In [None]:
# Delete every scan that has no segmentation with the same file name.
# The segmentation folder is listed once and kept as a set, instead of being
# re-read from disk for every single scan.
segmentation_names = set(os.listdir(all_segmentations_path))

scans_to_delete = [
    scan for scan in os.listdir(all_scans_path) if scan not in segmentation_names
]

for scan in scans_to_delete:
    print(f"Deleting scan: {scan}")
    os.remove(os.path.join(all_scans_path, scan))

### Subset 0 & 1 scans

In [None]:
# Keep only complete pairs: a type-0 scan needs its "<id>_1.nii.gz" partner
# and a type-1 scan needs its "<id>_0.nii.gz" partner. Checking both
# directions also removes type-1 scans that have no type-0 counterpart.
# Scans of any other type are left untouched.
scan_names = os.listdir(all_scans_path)
scans_in_folder = set(scan_names)  # listed once instead of once per scan
partner_type = {0: 1, 1: 0}

scans_to_delete = []
for scan in scan_names:
    name_parts = scan.split("_")
    # Parse the whole type token (up to the first dot), not just its first digit
    scan_type = int(name_parts[1].split(".")[0])
    if scan_type not in partner_type:
        continue
    partner_name = f"{name_parts[0]}_{partner_type[scan_type]}.nii.gz"
    if partner_name not in scans_in_folder:
        scans_to_delete.append(scan)

for scan in scans_to_delete:
    print(f"Deleting scan: {scan}")
    os.remove(os.path.join(all_scans_path, scan))

### Read and write every file to resolve corrupted segmentations issue
Some segmentations are corrupted and not recognized as .nii.gz files. Opening and rewriting them with SimpleITK resolves this issue. 

In [None]:
import SimpleITK as sitk

# Round-trip every segmentation through SimpleITK: read it and write it back
# to the same path, replacing the file on disk.
for segm in os.listdir(all_segmentations_path):
    print(f"Currently processing: {segm}")
    segm_path = os.path.join(all_segmentations_path, segm)
    sitk.WriteImage(sitk.ReadImage(segm_path), segm_path)
    

In [None]:
# Try to read the first bytes of every segmentation through gzip and report
# the files for which that fails.
import gzip

for filename in os.listdir(all_segmentations_path):
    try:
        with gzip.open(os.path.join(all_segmentations_path, filename), "rb") as gz_file:
            gz_file.read(10)
        print("Valid gzip")
    except Exception as err:
        # Diagnostic cell: any failure is reported together with the file name
        print(filename, f"Invalid gzip file: {err}", sep="\n")


### Segment liver and apply bounding box

In [None]:
"""
Segments the liver and applies a bounding box 
"""

# For every scan in all_scans_path that has no output in paired_scans_path yet:
# set all voxels outside the liver mask to -1000, crop the volume to the mask's
# bounding box and save it under the same file name in paired_scans_path.

# List the output folder once. Every scan name is written at most once per run,
# so this snapshot stays valid for the whole loop.
already_processed = set(os.listdir(paired_scans_path))

for scan in os.listdir(all_scans_path):
    if scan in already_processed:
        print(f"skipping: {scan}, since it already exists")
        continue

    print(f"currently processing: {scan}")

    # Load the image and the segmentation stored under the same file name
    image = nib.load(os.path.join(all_scans_path, scan))
    segmentation = nib.load(os.path.join(all_segmentations_path, scan))

    image_data = image.get_fdata()
    segmentation_data = segmentation.get_fdata()

    # Segmentation labels 12 and 13 together form the liver mask
    liver_mask = (segmentation_data == 12) | (segmentation_data == 13)

    # Without a single liver voxel there is no bounding box: min()/max() over
    # the empty index array would raise and abort the run for all other scans.
    if not liver_mask.any():
        print(f"skipping: {scan}, since its segmentation contains no liver voxels")
        continue

    # Apply the mask: everything outside the liver becomes -1000
    liver_image = np.copy(image_data)
    liver_image[~liver_mask] = -1000

    # Bounding box of the liver mask (inclusive min/max index per axis)
    mask_indices = np.argwhere(liver_mask)
    min_indices = mask_indices.min(axis=0)
    max_indices = mask_indices.max(axis=0)

    # Crop the masked image to the bounding box
    cropped_liver_image = liver_image[
        min_indices[0]:max_indices[0] + 1,
        min_indices[1]:max_indices[1] + 1,
        min_indices[2]:max_indices[2] + 1,
    ]

    # NOTE(review): the affine and header of the uncropped image are reused, so
    # the crop offset is not reflected in the saved file's spatial metadata --
    # confirm this is acceptable for downstream use.
    new_image = nib.Nifti1Image(cropped_liver_image, affine=image.affine, header=image.header)

    # Save the new NIfTI image to a file with the original name
    output_file_path = os.path.join(paired_scans_path, scan)
    nib.save(new_image, output_file_path)

### Segmentation downsizing

In [None]:
# 3D structuring element with connectivity 1 (a voxel plus its face neighbours)
structure = generate_binary_structure(3, 1)

for scan in os.listdir(paired_scans_path):
    print(f"Currently processing: {scan}")
    segmentation = nib.load(os.path.join(all_segmentations_path, scan))
    segmentation_data = segmentation.get_fdata()

    # Liver mask = labels 12 and 13; its bounding box defines the crop region
    liver_mask = np.logical_or(segmentation_data == 12, segmentation_data == 13)
    mask_indices = np.argwhere(liver_mask)
    min_indices = mask_indices.min(axis=0)
    max_indices = mask_indices.max(axis=0)

    # Crop the segmentation to the liver bounding box (max index is inclusive)
    x0, y0, z0 = min_indices[:3]
    x1, y1, z1 = max_indices[:3] + 1
    cropped_segmentation_data = segmentation_data[x0:x1, y0:y1, z0:z1]

    # Tumor mask (label 13) inside the cropped region, grown by 8 dilation steps
    tumor_mask = cropped_segmentation_data == 13
    dilated_tumor_mask = binary_dilation(tumor_mask, structure=structure, iterations=8)

    # Add batch and channel axes and convert to float for interpolation
    dilated_tumor_mask = torch.tensor(dilated_tumor_mask)[None, None].float()

    # Resample the dilated mask to a fixed 128 x 128 x 32 grid
    downsampled_tumor_mask = F.interpolate(
        dilated_tumor_mask,
        size=(128, 128, 32),
        mode="trilinear",
        align_corners=False,
    )
    downsampled_tumor_mask = downsampled_tumor_mask.squeeze().numpy()

    # Store the result as "<scan name without .nii.gz>.npy"
    output_file_path = os.path.join(paired_segmentations_path, scan[:-7] + ".npy")
    np.save(output_file_path, downsampled_tumor_mask)

### Scan resizing
These scans are used to overlay with the saliency maps.

In [None]:
# Placeholder: folder whose scans are intensity-scaled, resized and written
# back under the same file names by the cell below. Must be set before running.
resize_scans_path = None


In [None]:
# Intensity scaling (range [-100, 200] mapped to [0, 1], clipped) followed by
# resizing to 256 x 256 x 64. The pipeline is the same for every scan, so it is
# built once.
transform = [
    ScaleIntensityRange(a_min=-100, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    Resize((256, 256, 64), mode="trilinear"),
]
preprocess = Compose(transform)

for scan in os.listdir(resize_scans_path):
    print(f"Currently processing: {scan}")
    scan_path = os.path.join(resize_scans_path, scan)
    image = nib.load(scan_path)

    # Add a leading channel axis for the transforms and drop it again afterwards
    image_data = image.get_fdata()[np.newaxis]
    image_data = preprocess(image_data).squeeze()

    # NOTE(review): the file is overwritten in place, so running this cell a
    # second time rescales and resizes already-processed data. The original
    # affine and header are reused for the resized volume -- confirm both are
    # acceptable.
    nifti_image = nib.Nifti1Image(image_data, affine=image.affine, header=image.header)
    nib.save(nifti_image, scan_path)


# Testing (AMCore) 

In [None]:
# Useful paths
# Placeholders: every value must be set to a real location before the cells
# below are run.
 
# Working folders for the scans and their segmentations; a segmentation is
# looked up under the same file name as its scan.
all_scans_path = None
all_segmentations_path = None

# Source folder: its "*_0000.nii.gz" files are copied to all_scans_path and
# its other ".nii.gz" files to all_segmentations_path.
test_data_path = None
# Output folders for the cropped liver scans and the down-sized tumor masks.
paired_scans_path = None
paired_segmentations_path = None

# NOTE(review): not referenced by any other cell visible in this file.
resized_paired_scans_path = None

### copy to correct folders

In [None]:
# Distribute the test files: "*_0000.nii.gz" files are scans (copied with the
# "_0000" part removed from the name), every other ".nii.gz" file is copied to
# the segmentation folder. Files with another extension are ignored.
for scan in os.listdir(test_data_path):
    if not scan.endswith(".nii.gz"):
        continue

    source = os.path.join(test_data_path, scan)
    if scan.endswith("_0000.nii.gz"):
        target = os.path.join(all_scans_path, scan.replace("_0000", ""))
    else:
        target = os.path.join(all_segmentations_path, scan)
    shutil.copy(source, target)

### Delete scans and segmentations with type >= 2

In [None]:
# Delete every scan and segmentation whose type number is 2 or higher.
# File names are assumed to look like "<id>_<type>.nii.gz". The whole <type>
# token (everything between the first underscore and the first dot after it)
# is parsed, so multi-digit types such as "_10" are not mistaken for type 1,
# which is what looking at the first digit only would do.
scans_to_delete = [
    scan
    for scan in os.listdir(all_scans_path)
    if int(scan.split("_")[1].split(".")[0]) >= 2
]

for scan in scans_to_delete:
    print(f"Deleting scan: {scan}")
    os.remove(os.path.join(all_scans_path, scan))

# Same rule for the segmentations
segmentations_to_delete = [
    segm
    for segm in os.listdir(all_segmentations_path)
    if int(segm.split("_")[1].split(".")[0]) >= 2
]

for segm in segmentations_to_delete:
    print(f"Deleting segmentation: {segm}")
    os.remove(os.path.join(all_segmentations_path, segm))


### Delete scans without segmentation

In [None]:
# Delete every scan that has no segmentation with the same file name.
# The segmentation folder is listed once and kept as a set, instead of being
# re-read from disk for every single scan.
segmentation_names = set(os.listdir(all_segmentations_path))

scans_to_delete = [
    scan for scan in os.listdir(all_scans_path) if scan not in segmentation_names
]

for scan in scans_to_delete:
    print(f"Deleting scan: {scan}")
    os.remove(os.path.join(all_scans_path, scan))

### Subset paired scans

In [None]:
# Keep only complete pairs: a type-0 scan needs its "<id>_1.nii.gz" partner
# and a type-1 scan needs its "<id>_0.nii.gz" partner. Checking both
# directions also removes type-1 scans that have no type-0 counterpart.
# Scans of any other type are left untouched.
scan_names = os.listdir(all_scans_path)
scans_in_folder = set(scan_names)  # listed once instead of once per scan
partner_type = {0: 1, 1: 0}

scans_to_delete = []
for scan in scan_names:
    name_parts = scan.split("_")
    # Parse the whole type token (up to the first dot), not just its first digit
    scan_type = int(name_parts[1].split(".")[0])
    if scan_type not in partner_type:
        continue
    partner_name = f"{name_parts[0]}_{partner_type[scan_type]}.nii.gz"
    if partner_name not in scans_in_folder:
        scans_to_delete.append(scan)

for scan in scans_to_delete:
    print(f"Deleting scan: {scan}")
    os.remove(os.path.join(all_scans_path, scan))

### Segment liver and apply bounding box

In [None]:
"""
Segments the liver and applies a bounding box 
"""

# For every scan in all_scans_path that has no output in paired_scans_path yet:
# set all voxels outside the liver mask to -1000, crop the volume to the mask's
# bounding box and save it under the same file name in paired_scans_path.

# List the output folder once. Every scan name is written at most once per run,
# so this snapshot stays valid for the whole loop.
already_processed = set(os.listdir(paired_scans_path))

for scan in os.listdir(all_scans_path):
    if scan in already_processed:
        print(f"skipping: {scan}, since it already exists")
        continue

    print(f"currently processing: {scan}")

    # Load the image and the segmentation stored under the same file name
    image = nib.load(os.path.join(all_scans_path, scan))
    segmentation = nib.load(os.path.join(all_segmentations_path, scan))

    image_data = image.get_fdata()
    segmentation_data = segmentation.get_fdata()

    # Segmentation labels 12 and 13 together form the liver mask
    liver_mask = (segmentation_data == 12) | (segmentation_data == 13)

    # Without a single liver voxel there is no bounding box: min()/max() over
    # the empty index array would raise and abort the run for all other scans.
    if not liver_mask.any():
        print(f"skipping: {scan}, since its segmentation contains no liver voxels")
        continue

    # Apply the mask: everything outside the liver becomes -1000
    liver_image = np.copy(image_data)
    liver_image[~liver_mask] = -1000

    # Bounding box of the liver mask (inclusive min/max index per axis)
    mask_indices = np.argwhere(liver_mask)
    min_indices = mask_indices.min(axis=0)
    max_indices = mask_indices.max(axis=0)

    # Crop the masked image to the bounding box
    cropped_liver_image = liver_image[
        min_indices[0]:max_indices[0] + 1,
        min_indices[1]:max_indices[1] + 1,
        min_indices[2]:max_indices[2] + 1,
    ]

    # NOTE(review): the affine and header of the uncropped image are reused, so
    # the crop offset is not reflected in the saved file's spatial metadata --
    # confirm this is acceptable for downstream use.
    new_image = nib.Nifti1Image(cropped_liver_image, affine=image.affine, header=image.header)

    # Save the new NIfTI image to a file with the original name
    output_file_path = os.path.join(paired_scans_path, scan)
    nib.save(new_image, output_file_path)

In [None]:
# 3D structuring element with connectivity 1 (a voxel plus its face neighbours)
structure = generate_binary_structure(3, 1)

for scan in os.listdir(paired_scans_path):
    print(f"Currently processing: {scan}")
    segmentation = nib.load(os.path.join(all_segmentations_path, scan))
    segmentation_data = segmentation.get_fdata()

    # Liver mask = labels 12 and 13; its bounding box defines the crop region
    liver_mask = np.logical_or(segmentation_data == 12, segmentation_data == 13)
    mask_indices = np.argwhere(liver_mask)
    min_indices = mask_indices.min(axis=0)
    max_indices = mask_indices.max(axis=0)

    # Crop the segmentation to the liver bounding box (max index is inclusive)
    x0, y0, z0 = min_indices[:3]
    x1, y1, z1 = max_indices[:3] + 1
    cropped_segmentation_data = segmentation_data[x0:x1, y0:y1, z0:z1]

    # Tumor mask (label 13) inside the cropped region, grown by 8 dilation steps
    tumor_mask = cropped_segmentation_data == 13
    dilated_tumor_mask = binary_dilation(tumor_mask, structure=structure, iterations=8)

    # Add batch and channel axes and convert to float for interpolation
    dilated_tumor_mask = torch.tensor(dilated_tumor_mask)[None, None].float()

    # Resample the dilated mask to a fixed 128 x 128 x 32 grid
    downsampled_tumor_mask = F.interpolate(
        dilated_tumor_mask,
        size=(128, 128, 32),
        mode="trilinear",
        align_corners=False,
    )
    downsampled_tumor_mask = downsampled_tumor_mask.squeeze().numpy()

    # Store the result as "<scan name without .nii.gz>.npy"
    output_file_path = os.path.join(paired_segmentations_path, scan[:-7] + ".npy")
    np.save(output_file_path, downsampled_tumor_mask)

In [None]:
# Mean size per axis over all scans in paired_scans_path, shown as a bar chart
dims = np.array([
    nib.load(os.path.join(paired_scans_path, scan_file)).shape
    for scan_file in os.listdir(paired_scans_path)
])
avg_dims = dims.mean(axis=0)

plt.bar(['X', 'Y', 'Z'], avg_dims)
plt.ylabel('Average size')
plt.title('Average dimension of scans in paired_scans')
plt.show()