# CryoET Dataset Pre-Processing

Requirements:

- Raw tomogram in MRC format, processed in Fiji/ImageJ (histogram equalisation) and converted to NIfTI format
    
- Membrane segmentation in CMM format (segmentation coordinates)

In [1]:
# Import modules

import os
import numpy as np

import nibabel as nib
import json

import gzip
import shutil

### Extract 3D Patches

In [2]:
# Modified to handle the origin mismatch (final version; crops across all slices).
def extract_2D_3D_patchvolumes(raw_file, label_file, output_dir3D, start_name, patch_size=128, stride=50):
    """
    Extracts 3D patch volumes from raw and label NIfTI images and saves them with origins matched to the original volumes.

    A ``patch_size`` x ``patch_size`` window is slid over the (y, x) plane of every z-slice with step
    ``stride``. The first time a window contains at least 50 label voxels equal to 1 on some slice,
    that same window is cropped through the full depth of both volumes and saved. Each (y, x) window
    is saved at most once, tagged with the z of the slice that triggered it.

    Args:
        raw_file (str): Path to the raw NIfTI file.
        label_file (str): Path to the label NIfTI file.
        output_dir3D (str): Directory where the 3D patch volumes will be saved (created if missing).
        start_name (str): Prefix for the saved patch files.
        patch_size (int, optional): Size of each patch in y and x (default is 128).
        stride (int, optional): Stride for patch extraction (default is 50).

    Returns:
        None

    Notes:
        - The created patches are in .nii.gz compressed format. Use uncompress_nii_gz_files to create
          .nii files that can be used in the UNet model.
        - Files are named ``<start_name>z<z>y<y>x<x>.nii.gz`` (raw) and ``..._gt.nii.gz`` (label).
        - Each patch volume has shape (patch_size, patch_size, number of raw z-slices).
        - The raw patch keeps the raw affine; the label patch keeps the label affine with its
          translation replaced by the raw origin. The patch origin is therefore the whole-volume
          origin, not the world position of voxel (y, x, 0).
        - ``patchvol<start_name>.json`` in ``output_dir3D`` lists a {'z', 'x', 'y'} record per patch.
        - Prints the number of files in ``output_dir3D``, the number of cropped patches, and how many
          patches have an origin equal to the original volume's origin.
    """
    # Load the raw image and label data
    raw_img = nib.load(raw_file)
    label_img = nib.load(label_file)

    # Read the voxel arrays once, instead of calling get_fdata() inside the slice and crop loops.
    raw_data = raw_img.get_fdata()
    label_data = label_img.get_fdata()

    # Crop through the volume's own depth rather than a fixed slice count, so volumes
    # with any number of z-slices work (a fixed size overflows or zero-pads).
    depth = raw_data.shape[2]

    # Get the origin values for the original volume
    raw_origin = raw_img.affine[:3, 3]

    # All patches share the same affines, so build them once. The raw affine already
    # carries raw_origin; the label affine keeps its own rotation/scale but takes the
    # raw translation so both patches line up.
    cropped_raw_affine = raw_img.affine.copy()
    cropped_raw_affine[:3, 3] = raw_origin
    cropped_label_affine = label_img.affine.copy()
    cropped_label_affine[:3, 3] = raw_origin

    # Create the output directory for the 3D patches if it doesn't exist
    os.makedirs(output_dir3D, exist_ok=True)

    patchvol_info = []          # one {'z', 'x', 'y'} record per saved patch
    unique_coordinates = set()  # (y, x) windows that have already been saved
    patchvol_count = 0
    origin_match_count = 0

    for z in range(depth):
        label_slice = label_data[:, :, z]

        # Slide the window over the (y, x) plane of this slice
        for y in range(0, raw_data.shape[0] - patch_size + 1, stride):
            for x in range(0, raw_data.shape[1] - patch_size + 1, stride):
                # Each (y, x) window is saved once, for the first slice on which it qualifies
                if (y, x) in unique_coordinates:
                    continue

                # Skip windows with fewer than 50 label voxels on this slice
                label_patch = label_slice[y:y + patch_size, x:x + patch_size]
                if np.sum(label_patch == 1) < 50:
                    continue

                # Crop the same (y, x) window across all z-slices of both volumes
                cropped_raw = raw_data[y:y + patch_size, x:x + patch_size, :depth].copy()
                cropped_label = label_data[y:y + patch_size, x:x + patch_size, :depth].copy()

                cropped_raw_img = nib.Nifti1Image(cropped_raw, cropped_raw_affine, raw_img.header)
                cropped_label_img = nib.Nifti1Image(cropped_label, cropped_label_affine, label_img.header)

                patch_file = os.path.join(output_dir3D, f'{start_name}z{z}y{y}x{x}.nii.gz')
                offset_file = os.path.join(output_dir3D, f'{start_name}z{z}y{y}x{x}_gt.nii.gz')

                nib.save(cropped_raw_img, patch_file)
                nib.save(cropped_label_img, offset_file)

                # Record the patch and mark its window as used
                patchvol_info.append({'z': z, 'x': x, 'y': y})
                unique_coordinates.add((y, x))
                patchvol_count += 1

                # Compare the origin values for the patch and original volume
                patch_origin = cropped_raw_img.affine[:3, 3]
                if np.array_equal(patch_origin, raw_origin):
                    origin_match_count += 1
                else:
                    print(f"Origin values mismatch for patch: {patch_file}")
                    print(f"Patch origin: {patch_origin}")
                    print(f"Original origin: {raw_origin}")

    # Save patch information as a JSON file
    patch_info_file = os.path.join(output_dir3D, f'patchvol{start_name}.json')
    with open(patch_info_file, 'w') as f:
        json.dump(patchvol_info, f)

    # Counts every entry in the directory (patch pairs, the JSON file, and anything pre-existing)
    file_count_3D = len(os.listdir(output_dir3D))
    print(f"Number of files in output_dir3D: {file_count_3D}")

    # Print the number of patch volumes cropped
    print(f"Number of patch volumes cropped: {patchvol_count}")

    # Print the number of patch volumes with origin matched
    print(f"Number of patch volumes with origin matched: {origin_match_count}")

In [3]:
# Inputs: equalised raw tomogram and its binary membrane label (both NIfTI)
raw_file = r"1.1_folder/Pos_3_6_9_test_data_2/Pos_3_6_9_52Apx_rawEqualizedvolume.nii"
label_file = r"1.1_folder/Pos_3_6_9_test_data_2/Pos_3_6_9_binary_label_equalized.nii"

# Output: compressed 3D patch pairs, prefixed with the dataset name
output_dir3D = r"1.2_folder/Pos_3_6_9_test_data_3D_patches"
start_name = "Pos_3_6_9"

extract_2D_3D_patchvolumes(
    raw_file,
    label_file,
    output_dir3D,
    start_name,
    patch_size=128,
    stride=50,
)

Number of files in output_dir3D: 647
Number of patch volumes cropped: 323
Number of patch volumes with origin matched: 323


### Uncompress .nii.gz files and store tomogram and label patches in separate folders

In [4]:
def uncompress_nii_gz_files(input_dir, tomogram_output_dir, label_output_dir):
    """
    Decompress every ``.nii.gz`` file in ``input_dir`` and sort the results into two folders.

    Files whose name ends in ``_gt.nii.gz`` are written to ``label_output_dir``; every other
    ``.nii.gz`` file is written to ``tomogram_output_dir``. The output name is the input name
    without the trailing ``.gz``. Files that do not end in ``.nii.gz`` are ignored. Both
    output folders are created when missing, and one line is printed per decompressed file.

    Args:
        input_dir (str): Directory containing the compressed .nii.gz files.
        tomogram_output_dir (str): Destination for the decompressed tomogram .nii files.
        label_output_dir (str): Destination for the decompressed label (``_gt``) .nii files.

    Returns:
        None
    """
    # Tomogram folder first, then label folder
    for target_dir in (tomogram_output_dir, label_output_dir):
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

    for entry in os.listdir(input_dir):
        if not entry.endswith(".nii.gz"):
            continue

        source_path = os.path.join(input_dir, entry)

        # "_gt" marks a ground-truth label patch
        destination_dir = label_output_dir if entry.endswith("_gt.nii.gz") else tomogram_output_dir
        destination_path = os.path.join(destination_dir, entry[:-len(".gz")])

        with gzip.open(source_path, 'rb') as compressed, open(destination_path, 'wb') as plain:
            shutil.copyfileobj(compressed, plain)
        print(f"Uncompressed and saved {entry} to {destination_path}")


In [5]:
# Compressed 3D patch pairs (same folder the extraction cell writes to)
input_dir = "1.2_folder/Pos_3_6_9_test_data_3D_patches"

# Destinations for the decompressed tomogram and label (_gt) patches
tomogram_output_dir = "1.2_folder/Pos_3_6_9_tomogram_patches"
label_output_dir = "1.2_folder/Pos_3_6_9_label_patches"

uncompress_nii_gz_files(
    input_dir=input_dir,
    tomogram_output_dir=tomogram_output_dir,
    label_output_dir=label_output_dir,
)

Uncompressed and saved Pos_3_6_9z108y150x850.nii.gz to 1.2_folder/Pos_3_6_9_tomogram_patches\Pos_3_6_9z108y150x850.nii
Uncompressed and saved Pos_3_6_9z108y150x850_gt.nii.gz to 1.2_folder/Pos_3_6_9_label_patches\Pos_3_6_9z108y150x850_gt.nii
Uncompressed and saved Pos_3_6_9z108y200x850.nii.gz to 1.2_folder/Pos_3_6_9_tomogram_patches\Pos_3_6_9z108y200x850.nii
Uncompressed and saved Pos_3_6_9z108y200x850_gt.nii.gz to 1.2_folder/Pos_3_6_9_label_patches\Pos_3_6_9z108y200x850_gt.nii
Uncompressed and saved Pos_3_6_9z108y250x850.nii.gz to 1.2_folder/Pos_3_6_9_tomogram_patches\Pos_3_6_9z108y250x850.nii
Uncompressed and saved Pos_3_6_9z108y250x850_gt.nii.gz to 1.2_folder/Pos_3_6_9_label_patches\Pos_3_6_9z108y250x850_gt.nii
Uncompressed and saved Pos_3_6_9z110y100x750.nii.gz to 1.2_folder/Pos_3_6_9_tomogram_patches\Pos_3_6_9z110y100x750.nii
Uncompressed and saved Pos_3_6_9z110y100x750_gt.nii.gz to 1.2_folder/Pos_3_6_9_label_patches\Pos_3_6_9z110y100x750_gt.nii
Uncompressed and saved Pos_3_6_9z110

### Plot Patches and Tomograms

In [6]:
def plot_2D_3D_patch_volumes(output_dir_3D, output_dir_2D, volume_index=0):
    """
    Interactively display one 3D patch volume next to the 2D patch at the same (y, x).

    3D raw/label volumes are found in ``output_dir_3D`` by the patterns
    ``patch_3D_raw_y*_x*.nii`` / ``patch_3D_label_y*_x*.nii``. They are paired by the (y, x)
    parsed from their names and ordered by (y, x); ``volume_index`` selects one pair. The 2D
    raw/label patches with the same (y, x) are loaded from ``output_dir_2D`` (patterns
    ``patch_raw_y*_x*.nii`` / ``patch_label_y*_x*.nii``). A slider, a Play widget and a
    play/pause button scrub through the z-slices of the 3D volume.

    NOTE(review): these file patterns do not match the names written by
    extract_2D_3D_patchvolumes in this notebook (``<start_name>z<z>y<y>x<x>.nii``), so this
    viewer presumably targets an older patch layout — confirm before use.

    Args:
        output_dir_3D (str): Directory holding the 3D raw/label patch volumes.
        output_dir_2D (str): Directory holding the 2D raw/label patches.
        volume_index (int, optional): Index into the (y, x)-sorted list of 3D volume pairs (default 0).

    Returns:
        None. Prints a message and returns early when no 3D volume pair or no matching 2D
        patch pair is found.

    Notes:
        Relies on ``nib``, ``plt`` (matplotlib.pyplot), ``widgets`` (ipywidgets), ``interact``
        and ``display`` already being in scope (assumed imported elsewhere in the notebook).
    """
    import glob
    import re

    # The "." before "nii" is escaped so it only matches a literal dot
    coord_pattern = re.compile(r"_y(\d+)_x(\d+)\.nii")

    def _index_by_coords(paths):
        # Map the (y, x) parsed from each file name to its path; unparseable names are skipped.
        by_coords = {}
        for path in paths:
            match = coord_pattern.search(os.path.basename(path))
            if match:
                by_coords[(int(match.group(1)), int(match.group(2)))] = path
        return by_coords

    # --- 3D volumes: pair raw and label by coordinate, not by list position ---
    raw_3D = _index_by_coords(glob.glob(os.path.join(output_dir_3D, "patch_3D_raw_y*_x*.nii")))
    label_3D = _index_by_coords(glob.glob(os.path.join(output_dir_3D, "patch_3D_label_y*_x*.nii")))
    sorted_coordinates = sorted(raw_3D.keys() & label_3D.keys())
    if not sorted_coordinates:
        print("No matching 3D raw/label patch volumes found.")
        return

    y_coord, x_coord = sorted_coordinates[volume_index]
    key = (y_coord, x_coord)
    print(raw_3D[key], key, label_3D[key])
    raw_data = nib.load(raw_3D[key]).get_fdata()
    label_data = nib.load(label_3D[key]).get_fdata()

    # --- 2D patches at the same (y, x) ---
    raw_2D = _index_by_coords(glob.glob(os.path.join(output_dir_2D, "patch_raw_y*_x*.nii")))
    label_2D = _index_by_coords(glob.glob(os.path.join(output_dir_2D, "patch_label_y*_x*.nii")))
    if key not in raw_2D or key not in label_2D:
        print("Matching 2D patch file not found.")
        return
    raw_2Ddata = nib.load(raw_2D[key]).get_fdata()
    label_2Ddata = nib.load(label_2D[key]).get_fdata()
    print(raw_2D[key], label_2D[key], key)

    def update(z):
        # Top row: slice z of the 3D volumes; bottom row: the fixed 2D patches
        _, axs = plt.subplots(2, 2, figsize=(14, 14))
        axs[1, 0].imshow(raw_2Ddata, cmap='gray')
        axs[1, 0].set_title(f"Raw Patch (y={y_coord}, x={x_coord})")
        axs[1, 1].imshow(label_2Ddata, cmap='gray')
        axs[1, 1].set_title(f"Label Patch (y={y_coord}, x={x_coord})")
        axs[0, 0].imshow(raw_data[:, :, z], cmap='gray')
        axs[0, 0].set_title(f"Raw Volume (y={y_coord}, x={x_coord})")
        axs[0, 1].imshow(label_data[:, :, z], cmap="gray", vmin=0, vmax=1)
        axs[0, 1].set_title(f"Label Volume (y={y_coord}, x={x_coord})")
        plt.show()

    # Slider and Play widget share one value, so playback drives the slider
    last_slice = raw_data.shape[2] - 1
    z_slider = widgets.IntSlider(min=0, max=last_slice, step=1, value=0, description="Slice")
    play = widgets.Play(interval=500, value=0, min=0, max=last_slice, step=1, description="Play")
    widgets.jslink((play, 'value'), (z_slider, 'value'))
    play_button = widgets.Button(description=">", layout=widgets.Layout(width='80px'))

    # The Play widget has no play()/stop() methods; playback is a boolean trait.
    # NOTE(review): ipywidgets >= 8 names it ``playing`` (older releases ``_playing``) —
    # confirm against the installed version.
    playing_trait = "playing" if play.has_trait("playing") else "_playing"

    def on_play_button_click(_button):
        # At the last slice, rewind instead of toggling playback
        if play.value == play.max:
            play.value = play.min
            play_button.description = ">"
        elif getattr(play, playing_trait):
            setattr(play, playing_trait, False)
            play_button.description = "Play"
        else:
            setattr(play, playing_trait, True)
            play_button.description = "Pause"

    play_button.on_click(on_play_button_click)
    display(widgets.HBox([z_slider, play, play_button]))
    interact(update, z=z_slider)