In [2]:
## Libraries ##
import os, glob
import random

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from IPython.display import display
import nibabel as nib
import cv2
import tensorflow as tf
from tensorflow import keras
import albumentations as A

# (1) Save 3D Volumes Into 2D Slices 

In [7]:
# Constants: axes along which sliceAndSaveVolumeImage cuts each volume
# (X = first, Y = second, Z = third array axis). With these settings only
# slices along the third axis are produced.
SLICE_X, SLICE_Y, SLICE_Z = False, False, True
### Functions ###
# 0-1 Normalize Image Intensity Range
def normalizeImageIntensityRange(img):
    """Linearly rescale intensities so the minimum maps to 0 and the maximum to 1.

    Parameters
    ----------
    img : array-like
        Image or volume of any shape.

    Returns
    -------
    numpy.ndarray
        Floating-point array of the same shape with values in [0, 1].
        A constant input (max == min) has no range to stretch and is returned
        as all zeros, instead of the all-NaN result a 0/0 division would give.
    """
    img = np.asarray(img)
    img_min = np.min(img)
    img_range = np.max(img) - img_min
    # Dividing by 1 when the range is zero keeps the same result dtype as the
    # normal path, while (img - img_min) is already all zeros in that case.
    denom = img_range if img_range != 0 else 1
    return (img - img_min) / denom

# Read image or mask volume
def loadVolume(volumePath, binarize_mask=False, normalize=False):
    """Load a NIfTI volume from disk as a floating-point array.

    Parameters
    ----------
    volumePath : str
        Path to the volume file, read with ``nib.load(...).get_fdata()``.
    binarize_mask : bool
        When equal to True (used for masks/labels), every value greater than 1
        (e.g. a 2.0 label) is lowered to 1.0 so the result is a binary mask.
    normalize : bool
        When equal to True, intensities are rescaled to [0, 1] with
        ``normalizeImageIntensityRange``.

    Returns
    -------
    numpy.ndarray
        The voxel data, optionally binarized and/or normalized.
    """
    data = nib.load(volumePath).get_fdata()
    if binarize_mask == True:
        # Cap label values at 1 so multi-valued labels collapse to a binary mask.
        data = np.minimum(data, 1)
    return normalizeImageIntensityRange(data) if normalize == True else data

# Save volume slice to file
def saveSlice(vol_slice, fname, path):
    """Write one 2-D slice to ``<path>/<fname>.png`` as an 8-bit grayscale image.

    Parameters
    ----------
    vol_slice : array-like
        2-D slice whose values are expected in [0, 1] (a normalized image or a
        binary mask). Values are clipped to [0, 1] before scaling to 0-255.
    fname : str
        File name without the ``.png`` extension.
    path : str
        Output directory; created if it does not exist.

    Raises
    ------
    OSError
        If OpenCV reports that the file could not be written.
    """
    # Clip first so out-of-range values cannot wrap around in the uint8 cast.
    image = np.uint8(np.clip(vol_slice, 0, 1) * 255)
    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(path, exist_ok=True)
    fout = os.path.join(path, f'{fname}.png')
    overwriting = os.path.exists(fout)
    if not cv2.imwrite(fout, image):
        raise OSError(f"Failed to write slice: {fout}")
    # Report only after the write actually succeeded.
    if overwriting:
        print(f"Overwriting file: [+] Slice saved: {fout}", end="\r")
    else:
        print(f"[+] Slice saved: {fout}", end="\r")

# Slice image in all directions and save
def sliceAndSaveVolumeImage(vol, fname, path):
    """Cut a 3-D volume into 2-D slices along each enabled axis and save them.

    Axes are enabled by the module-level flags SLICE_X, SLICE_Y and SLICE_Z,
    which are read at call time. Each slice is written with ``saveSlice`` under
    the name ``fname[:20] + '_slice<NNN>-<axis>' + fname[20:]``.

    Parameters
    ----------
    vol : numpy.ndarray
        Volume with exactly three dimensions (unpacked as dimx, dimy, dimz).
    fname : str
        Base file name for the slices (no extension).
    path : str
        Output directory passed through to ``saveSlice``.

    Returns
    -------
    int
        Total number of slices saved across all enabled axes.
    """
    dimx, dimy, dimz = vol.shape
    # (enabled flag, axis index into vol, suffix letter, number of slices)
    axes = (
        (SLICE_X, 0, 'x', dimx),
        (SLICE_Y, 1, 'y', dimy),
        (SLICE_Z, 2, 'z', dimz),
    )
    # NOTE(review): 20 is one short of the 21-char 'trnVolumeNNN__IMAGE__' /
    # '...__LABEL__' prefixes built by the slicing cells. The slice tag is
    # therefore inserted just before that prefix's final underscore.
    head, tail = fname[:20], fname[20:]
    cnt = 0
    for enabled, axis, letter, size in axes:
        if enabled == True:
            cnt += size
            for d in range(size):
                index = [slice(None)] * 3
                index[axis] = d
                # Every axis passes `fname=`. The original Y branch used `frame=`,
                # which raised TypeError.
                saveSlice(vol_slice=vol[tuple(index)],
                          fname=f"{head}_slice{str(d).zfill(3)}-{letter}{tail}",
                          path=path)
    return cnt


## - Training Slices

In [None]:
### TRAINING DATA: Read and process image and mask volumes ###
# Slice every training image/label volume pair into 2-D PNG files.
path = 'data/volumes/training_volumes'
img_out_dir = 'data/slices/training_slices/img/'
mask_out_dir = 'data/slices/training_slices/mask/'
min_depth = 14  # volumes with fewer slices than this along the last axis are skipped

image_files = sorted(glob.glob(f'{path}/img/*avw.nii.gz'))     # all image volumes
label_files = sorted(glob.glob(f'{path}/mask/*label.nii.gz'))  # all label volumes
assert image_files, f"No image volumes found in {path}/img"
# Images and labels are paired by sorted position, so the lists must line up one-to-one.
assert len(image_files) == len(label_files), "Mismatch in number of image and label files"

# cv2.imwrite does not create directories and fails silently when they are missing.
os.makedirs(img_out_dir, exist_ok=True)
os.makedirs(mask_out_dir, exist_ok=True)

count = 0
for subnum, (imgpath, lblpath) in enumerate(zip(image_files, label_files)):
    img = loadVolume(imgpath, normalize=True)
    lbl = loadVolume(lblpath, binarize_mask=True)
    print(f'Image[{str(subnum).zfill(2)}] - {type(img)} {img.shape}    min/max = {np.min(img)} < {np.max(img)}    {imgpath}')
    print(f'Label[{str(subnum).zfill(2)}] - {type(lbl)} {lbl.shape}    min/max = {np.min(lbl)} < {np.max(lbl)}    {lblpath}')

    # Skip shallow scans and any pair whose image and label shapes disagree.
    if img.shape[-1] >= min_depth and img.shape == lbl.shape:
        # os.path.basename handles both '/' and '\\' separators (the previous
        # split('\\')[1] crashed on non-Windows paths); [:-7] strips '.nii.gz'.
        Volume_image_fname = 'trnVolume' + str(subnum).zfill(3) + '__IMAGE__' + os.path.basename(imgpath)[:-7]
        Volume_label_fname = 'trnVolume' + str(subnum).zfill(3) + '__LABEL__' + os.path.basename(lblpath)[:-7]
        Icnt = sliceAndSaveVolumeImage(vol=img, fname=Volume_image_fname, path=img_out_dir)
        Lcnt = sliceAndSaveVolumeImage(vol=lbl, fname=Volume_label_fname, path=mask_out_dir)
        assert Icnt == Lcnt, f"Image/label slice count mismatch for volume {subnum}"
        count += Icnt
    else:
        print(f'Image[{str(subnum).zfill(2)}] - SKIPPED SLICING & SAVING VOLUMES')
print("\n\n\nAmount of Training Slices (2D images) =", count)

## - Testing Slices

In [None]:
### TEST DATA: Read and process image and mask volumes ###
# Slice every test image/label volume pair into 2-D PNG files.
path = 'data/volumes/test_volumes'
img_out_dir = 'data/slices/test_slices/img/'
mask_out_dir = 'data/slices/test_slices/mask/'
min_depth = 14  # volumes with fewer slices than this along the last axis are skipped

image_files = sorted(glob.glob(f'{path}/img/*avw.nii.gz'))  # all image volumes
# NOTE(review): unlike the training cell, this matches every .nii.gz in mask/,
# not only *label.nii.gz — confirm that is intended for the test set.
label_files = sorted(glob.glob(f'{path}/mask/*.nii.gz'))
assert image_files, f"No image volumes found in {path}/img"
# Images and labels are paired by sorted position, so the lists must line up one-to-one.
assert len(image_files) == len(label_files), "Mismatch in number of image and label files"

# cv2.imwrite does not create directories and fails silently when they are missing.
os.makedirs(img_out_dir, exist_ok=True)
os.makedirs(mask_out_dir, exist_ok=True)

count = 0
for subnum, (imgpath, lblpath) in enumerate(zip(image_files, label_files)):
    img = loadVolume(imgpath, normalize=True)
    lbl = loadVolume(lblpath, binarize_mask=True)
    print(f'Image[{str(subnum).zfill(2)}] - {type(img)} {img.shape}    min/max = {np.min(img)} < {np.max(img)}    {imgpath}')
    print(f'Label[{str(subnum).zfill(2)}] - {type(lbl)} {lbl.shape}    min/max = {np.min(lbl)} < {np.max(lbl)}    {lblpath}')

    # Skip shallow scans and any pair whose image and label shapes disagree.
    if img.shape[-1] >= min_depth and img.shape == lbl.shape:
        # os.path.basename handles both '/' and '\\' separators (the previous
        # split('\\')[1] crashed on non-Windows paths); [:-7] strips '.nii.gz'.
        Volume_image_fname = 'tstVolume' + str(subnum).zfill(3) + '__IMAGE__' + os.path.basename(imgpath)[:-7]
        Volume_label_fname = 'tstVolume' + str(subnum).zfill(3) + '__LABEL__' + os.path.basename(lblpath)[:-7]
        Icnt = sliceAndSaveVolumeImage(vol=img, fname=Volume_image_fname, path=img_out_dir)
        Lcnt = sliceAndSaveVolumeImage(vol=lbl, fname=Volume_label_fname, path=mask_out_dir)
        assert Icnt == Lcnt, f"Image/label slice count mismatch for volume {subnum}"
        count += Icnt
    else:
        print(f'Image[{str(subnum).zfill(2)}] - SKIPPED SLICING & SAVING VOLUMES')
print("\n\n\nAmount of Test Slices (2D images) =", count)

# (2) Create Set of Augmentation Slices

In [None]:
### CREATING AUGMENTATION FILES ##
# Write N_AUG_PER_SLICE randomly augmented copies of every training slice and
# its mask into separate augimg/ and augmask/ folders.
original_images_path = 'data/slices/training_slices/img'
original_masks_path = 'data/slices/training_slices/mask'
augmented_images_path = 'data/slices/training_slices/augimg'
augmented_masks_path = 'data/slices/training_slices/augmask'
os.makedirs(augmented_images_path, exist_ok=True)
os.makedirs(augmented_masks_path, exist_ok=True)

N_AUG_PER_SLICE = 5  # augmented versions generated per original slice
AUG_SEED = 42        # fixed seed so re-running the cell reproduces the same augmented set

# Seed the global RNGs before any augmentation parameters are sampled.
# NOTE(review): this covers albumentations versions that draw from `random` /
# `np.random`; newer releases may need `A.Compose(..., seed=AUG_SEED)` instead —
# confirm against the installed version.
random.seed(AUG_SEED)
np.random.seed(AUG_SEED)

# Augmentation pipeline
augmentations = A.Compose([A.HorizontalFlip(p=0.5),
                           A.VerticalFlip(p=0.5),
                           A.RandomBrightnessContrast(p=0.2)])


def _to_uint8(arr):
    """Map a float array in [0, 1] back to uint8 0-255, rounding (not truncating) and clipping."""
    return np.clip(np.rint(arr * 255.0), 0, 255).astype(np.uint8)


def _list_pngs(folder):
    """Return the sorted PNG file names in `folder`, ignoring any other files."""
    return sorted(f for f in os.listdir(folder) if f.lower().endswith('.png'))


# Images and masks are paired by sorted position, so both folders must hold matching slices.
image_filenames = _list_pngs(original_images_path)
mask_filenames = _list_pngs(original_masks_path)
assert len(image_filenames) == len(mask_filenames), "Mismatch in number of image and mask files"

for img_file, mask_file in zip(image_filenames, mask_filenames):
    img = cv2.imread(os.path.join(original_images_path, img_file), cv2.IMREAD_GRAYSCALE)
    mask = cv2.imread(os.path.join(original_masks_path, mask_file), cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None (rather than raising) when a file cannot be read.
    if img is None or mask is None:
        raise OSError(f"Could not read image/mask pair: {img_file} / {mask_file}")

    # Augment in float32 [0, 1].
    img = img.astype(np.float32) / 255.0
    mask = mask.astype(np.float32) / 255.0

    for i in range(N_AUG_PER_SLICE):
        augmented = augmentations(image=img, mask=mask)
        aug_img, aug_mask = augmented['image'], augmented['mask']
        aug_img_filename = f'aug_{i}_{img_file}'
        aug_mask_filename = f'aug_{i}_{mask_file}'

        # Round back to uint8; plain truncation can drop values by one level.
        img_ok = cv2.imwrite(os.path.join(augmented_images_path, aug_img_filename), _to_uint8(aug_img))
        mask_ok = cv2.imwrite(os.path.join(augmented_masks_path, aug_mask_filename), _to_uint8(aug_mask))
        if not (img_ok and mask_ok):
            raise OSError(f"Failed to write augmented pair: {aug_img_filename} / {aug_mask_filename}")

print('Augmentation complete. Check the augmented images and masks at the specified paths.')