In [1]:
import os
import json
import numpy as np
import pandas as pd
import nibabel as nib
from tqdm import tqdm
import SimpleITK as sitk

In [None]:
def load_brats23(path):
    """
    Load a BraTS 2023 volume from disk as a numpy array.

    The file is read with SimpleITK, reoriented to 'RPI' and converted
    to an array.
    Parameters:
        path: Path of the image file to read.
    Returns:
        The array produced by sitk.GetArrayFromImage for the reoriented
        image, i.e. axes ordered z, y, x (d, h, w).
    """
    def reorient(img_sitk, tgt='RPI'):
        """
        Reorient an sitk image to the requested target orientation.

        Any orientation string accepted by the filter may be passed, but
        the convention followed in this notebook is
        'RPI' for Python; 'LPS' for 3D Slicer.
        Parameters:
            img_sitk: An sitk image of shape [x, y, z].
            tgt: A string naming the target orientation.
        Returns:
            The sitk image returned by the orientation filter.
        """
        orient_filter = sitk.DICOMOrientImageFilter()
        orient_filter.SetDesiredCoordinateOrientation(tgt)
        return orient_filter.Execute(img_sitk)

    # read -> reorient -> array; the array conversion reverses the axis
    # order: x, y, z -> z, y, x (d, h, w)
    return sitk.GetArrayFromImage(reorient(sitk.ReadImage(path), tgt='RPI'))


def load_dwi(path):
    """
    Load a NIfTI volume with nibabel and return it as a (d, h, w) array.

    The volume is first rearranged so that its axes follow the R/L, P/A,
    I/S order (flipping any axis labelled 'L', 'A' or 'S'), then transposed
    so that the axis with the largest voxel spacing comes first. Isotropic
    data is treated as axial.
    Parameters:
        path: Path of the NIfTI file to read.
    Returns:
        A 3-D numpy array. For 4-D input only the first volume
        (index 0 of the last axis) is used.
    Raises:
        ValueError: If the stored data is neither 3-D nor 4-D.
    """
    def transpose_raw2rpi(img, orientation):
        """Permute/flip `img` from the file's axis codes to R, P, I order.

        Returns the rearranged array and the axis permutation applied.
        """
        # (code kept as-is, code that requires a flip) for each output axis
        axis_pairs = (('R', 'L'), ('P', 'A'), ('I', 'S'))
        orientation_transpose = [
            orientation.index(keep) if keep in orientation else orientation.index(flip)
            for keep, flip in axis_pairs
        ]
        img = np.transpose(img, orientation_transpose)
        for axis, (_, flip) in enumerate(axis_pairs):
            if flip in orientation:
                img = np.flip(img, axis=axis)
        return img, orientation_transpose

    def compute_spacing(affine, orientation_transpose):
        """Voxel spacing per axis, reordered to match the permuted array."""
        # column norms of the 3x3 part of the affine = voxel size per input axis
        spacing = np.sqrt(np.sum(affine[:3, :3] ** 2, axis=0))
        return spacing[orientation_transpose]

    def transpose_lps2dhw(img, spacing):
        """Move the axis with the largest spacing to the front (d, h, w)."""
        axes_by_plane = {
            0: (0, 2, 1),  # sagittal
            1: (1, 2, 0),  # coronal
            2: (2, 1, 0),  # axial
        }
        # Spacing is derived through a square root, so compare with a
        # tolerance: exact equality would classify isotropic data with
        # round-off noise as anisotropic and pick an arbitrary plane.
        if np.allclose(spacing, spacing[0]):  # isotropic data
            return np.transpose(img, axes_by_plane[2])  # axial
        _max_spacing_side = int(np.argmax(spacing, axis=-1))
        return np.transpose(img, axes_by_plane[_max_spacing_side])

    nifti_file = nib.load(path)
    img = nifti_file.get_fdata()
    # Accept plain 3-D volumes as well as 4-D series (first volume only);
    # indexing a 3-D array with four indices would raise IndexError.
    if img.ndim == 4:
        img = img[:, :, :, 0]
    elif img.ndim != 3:
        raise ValueError(f"Expected a 3-D or 4-D volume, got {img.ndim}-D data in {path}")
    affine = nifti_file.affine
    orientation = nib.orientations.aff2axcodes(affine)
    img, orientation_transpose = transpose_raw2rpi(img, orientation)
    spacing = compute_spacing(affine, orientation_transpose)
    img = transpose_lps2dhw(img, spacing)
    return img


def brats23_annotation(dataset, root_path='/data/brats2023', save_path='/data/pub_brain_5/brats23/'):
    """
    Collect, per patient, the slice indices whose segmentation is non-empty.

    Walks `<root_path>/<dataset>/<split>/<patient>/`, where split folders
    containing 'Train' map to 'train' and those containing 'Validation' map
    to 'test'; other folders are ignored. Patients without a segmentation
    file are skipped.
    Parameters:
        dataset: One of 'BraTS-GLI', 'BraTS-MEN', 'BraTS-MET', 'BraTS-PED'.
        root_path: Directory holding the dataset folders.
        save_path: Prefix used to build the study path stored in each row.
    Returns:
        A list of rows [study_path, slice_indices] where study_path is
        `<save_path>/<train|test>/<tumor_type>/<patient_id>` and
        slice_indices is a numpy array of indices along axis 0 of the
        loaded segmentation that contain any non-zero value.
    Raises:
        ValueError: If `dataset` is not a known BraTS dataset name.
    """
    tumor_types = {
        'BraTS-GLI': 'adult_glioma',
        'BraTS-MEN': 'adult_meningioma',
        'BraTS-MET': 'adult_metastasis',
        'BraTS-PED': 'pediatric_glioma',
    }
    # Fail early with a clear message instead of an unbound `tumor_type` later.
    if dataset not in tumor_types:
        raise ValueError(
            f"Unknown BraTS dataset {dataset!r}; expected one of {sorted(tumor_types)}"
        )
    tumor_type = tumor_types[dataset]

    def list_subdirs(parent):
        """Sorted names of the immediate sub-directories of `parent`."""
        return sorted(d for d in os.listdir(parent) if os.path.isdir(os.path.join(parent, d)))

    def is_segmentation(folder, name):
        """True for visible '.nii.gz' files whose name mentions seg/mask."""
        lowered = name.lower()
        return (
            os.path.isfile(os.path.join(folder, name))
            and name.endswith('.nii.gz')
            and not name.startswith('.')
            and ('seg' in lowered or 'mask' in lowered)
        )

    rows = []
    data_dir = os.path.join(root_path, dataset)
    for split_name in list_subdirs(data_dir):
        if 'Train' in split_name:
            save_dir = os.path.join(save_path, 'train', tumor_type)
        elif 'Validation' in split_name:
            save_dir = os.path.join(save_path, 'test', tumor_type)
        else:
            continue

        set_dir = os.path.join(data_dir, split_name)
        print("Process:", set_dir)
        for patient_id in tqdm(list_subdirs(set_dir)):
            patient_dir = os.path.join(set_dir, patient_id)
            # Sorted so the pick is deterministic when several files match
            # (the last match wins, as before).
            candidates = sorted(p for p in os.listdir(patient_dir) if is_segmentation(patient_dir, p))
            if not candidates:
                continue
            segment_path = os.path.join(patient_dir, candidates[-1])
            seg_arr = load_brats23(segment_path)
            # indices along axis 0 of slices containing any labelled voxel
            non_zero_slices = np.where(np.any(seg_arr != 0, axis=(1, 2)))[0]
            rows.append([str(os.path.join(save_dir, patient_id)), non_zero_slices])
    return rows
    

def stroke_annotation(root_path='/data/stroke/', save_dir='/data/pub_brain_5/stroke/'):
    """
    Collect, per stroke study, the slice indices whose mask is non-empty.

    Study ids are the entries of `save_dir`. For each one the metadata file
    `<root_path>/<study_id>/<study_id>_meta.json` is read, its
    'OriginalStudyPath' is mapped from 'raw_data' to 'DWI_masks', and a
    '*stroke_mask*.nii.gz' file is looked up in that directory. Studies with
    no metadata file, no mask directory or no mask file are counted as
    invalid and skipped.
    Parameters:
        root_path: Directory holding one metadata folder per study.
        save_dir: Directory whose entries define the study ids; also the
            prefix of the study path stored in each row.
    Returns:
        A list of rows [study_path, slice_indices]. slice_indices holds the
        non-empty slice indices along axis 0 of the mask followed by the
        same indices shifted by the number of slices (second volume).
    """
    def find_stroke_mask(study_id):
        """Return the mask path for `study_id`, or None if it cannot be found."""
        meta_json = os.path.join(root_path, study_id, f'{study_id}_meta.json')
        if not os.path.isfile(meta_json):
            return None
        with open(meta_json, 'r') as file:
            meta_data = json.load(file)
        dwi_masks_dir = meta_data['OriginalStudyPath'].replace('raw_data', 'DWI_masks')
        if not os.path.isdir(dwi_masks_dir):
            return None
        # Sorted so the pick is deterministic when several files match
        # (the last match wins, as before).
        matches = sorted(
            p for p in os.listdir(dwi_masks_dir)
            if os.path.isfile(os.path.join(dwi_masks_dir, p))
            and p.endswith('.nii.gz')
            and not p.startswith('.')
            and 'stroke_mask' in p.lower()
        )
        return os.path.join(dwi_masks_dir, matches[-1]) if matches else None

    rows = []
    not_valid = []  # ids of the studies that had to be skipped
    # sorted for a reproducible row order regardless of filesystem listing order
    for study_id in tqdm(sorted(os.listdir(save_dir))):
        segment_path = find_stroke_mask(study_id)
        if segment_path is None:
            not_valid.append(study_id)
            continue
        seg_arr = load_dwi(segment_path)
        non_zero_slices = np.where(np.any(seg_arr != 0, axis=(1, 2)))[0]
        non_zero_slices_flip = non_zero_slices + seg_arr.shape[0]  # shift indices for the second volume
        non_zero_slices_total = np.concatenate([non_zero_slices, non_zero_slices_flip], axis=0)
        rows.append([str(os.path.join(save_dir, study_id)), non_zero_slices_total])

    print(f"# valid studies: {len(rows)}; # invalid studies: {len(not_valid)}")
    return rows

In [None]:
# main: gather the annotation rows of every dataset and write them to one CSV
columns = ['study', 'target_slices']

rows = []
# the four BraTS 2023 cohorts first, then the stroke studies
for dataset in ('BraTS-GLI', 'BraTS-MEN', 'BraTS-MET', 'BraTS-PED'):
    rows += brats23_annotation(dataset)
rows += stroke_annotation()

# one row per study: its path and the array of annotated slice indices
df = pd.DataFrame(data=rows, columns=columns)
df.to_csv('./ground_truth.csv', index=False)
print("CSV file created: ./ground_truth.csv")