## Variables

In [1]:
# Root folder of the raw BTCV dataset (machine-specific absolute path;
# adjust it when running on another machine).
data_path = (
    "/home/ismail/projet_PFE/Hands-on-nnUNet/nnUNetFrame/dataset/RawData"
)

## Functions

In [None]:
def organ_mapper():
    """
    Build the lookup tables between BTCV label values and organ names.

    Returns:
        tuple: ``(key_to_organ, organ_to_key)`` where ``key_to_organ`` maps
        each integer label (0-13) to its organ name, and ``organ_to_key``
        maps the lower-cased organ name back to its integer label.
    """
    # Position in this tuple is the label value stored in the masks.
    organ_names = (
        "background",  # 0
        "spleen",  # 1
        "right kidney",  # 2
        "left kidney",  # 3
        "gallbladder",  # 4
        "esophagus",  # 5
        "liver",  # 6
        "stomach",  # 7
        "aorta",  # 8
        "inferior vena cava",  # 9
        "portal vein & splenic vein",  # 10
        "pancreas",  # 11
        "right adrenal gland",  # 12
        "left adrenal gland",  # 13
    )
    key_to_organ = dict(enumerate(organ_names))
    organ_to_key = {name.lower(): label for label, name in key_to_organ.items()}
    return key_to_organ, organ_to_key


def transform_to_csv(data_root):
    """
    Index the BTCV images and label masks and save the index as a CSV file.

    Each NIfTI file in ``<data_root>/imagesTr`` is paired with the NIfTI file
    at the same sorted position in ``<data_root>/labelsTr``. The label values
    present in every mask are recorded, and the index is written to
    ``<data_root>/dataset.csv``.

    Args:
        data_root (str): path to the BTCV dataset (must contain ``imagesTr``
            and ``labelsTr``).

    Returns:
        annotations (list): list of dicts with keys ``img_path``,
            ``mask_path`` and ``organs``. ``organs`` is the sorted list of
            integer label values found in the mask, including 0 when
            background voxels are present.

    Raises:
        ValueError: if the two directories contain different numbers of
            NIfTI files, so they cannot be paired one-to-one.
    """

    def _nifti_files(directory):
        # Skip hidden/system files (".DS_Store", macOS "._*" forks) and
        # non-NIfTI files so the two sorted lists stay aligned.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if not name.startswith(".") and name.endswith((".nii", ".nii.gz"))
        ]

    images_dir = os.path.join(data_root, "imagesTr")
    masks_dir = os.path.join(data_root, "labelsTr")
    imgs = _nifti_files(images_dir)
    masks = _nifti_files(masks_dir)
    if len(imgs) != len(masks):
        # zip() would silently drop the surplus files and hide the mismatch.
        raise ValueError(
            f"Found {len(imgs)} images in {images_dir} but {len(masks)} "
            f"masks in {masks_dir}; cannot pair them."
        )

    annotations = []
    for img_path, mask_path in zip(imgs, masks):
        # Reading through dataobj avoids get_fdata()'s float64 copy of the
        # whole label volume; only the distinct values are needed.
        mask_data = np.asanyarray(nib.load(mask_path).dataobj)
        organs = [int(organ) for organ in np.unique(mask_data).tolist()]
        annotations.append(
            {"img_path": img_path, "mask_path": mask_path, "organs": organs}
        )

    # Explicit columns keep the CSV header even when no scans are found.
    annotations_df = pd.DataFrame(
        annotations, columns=["img_path", "mask_path", "organs"]
    )
    annotations_df.to_csv(os.path.join(data_root, "dataset.csv"), index=False, sep=",")

    return annotations

In [None]:
def clip_and_rescale(ct_array, min_hu=-175, max_hu=250):
    """
    Clip CT intensities to a Hounsfield-unit window and rescale it to [0, 1].

    Args:
        ct_array (array-like): CT intensities in HU.
        min_hu (float): lower bound of the window; mapped to 0.
        max_hu (float): upper bound of the window; mapped to 1.

    Returns:
        numpy.ndarray: values clipped to ``[min_hu, max_hu]`` and linearly
        rescaled so that ``min_hu -> 0`` and ``max_hu -> 1`` (same shape as
        ``ct_array``).

    Raises:
        ValueError: if ``max_hu <= min_hu``.
    """
    if max_hu <= min_hu:
        raise ValueError(
            f"max_hu ({max_hu}) must be greater than min_hu ({min_hu})"
        )
    clipped = np.clip(ct_array, min_hu, max_hu)
    # Subtract the lower bound rather than adding abs(min_hu), so windows with
    # a positive lower bound also map to [0, 1] (identical for min_hu <= 0).
    return (clipped - min_hu) / (max_hu - min_hu)

def to_tensor(image, mask):
    """
    Wrap an image/mask pair of numpy arrays as float32 torch tensors.

    A new leading axis of size 1 is prepended to both arrays, so an input of
    shape ``(H, W)`` becomes ``(1, H, W)``.

    Returns:
        tuple: ``(image_tensor, mask_tensor)``.
    """

    def _with_leading_axis(array):
        tensor = torch.from_numpy(array).float()
        return tensor[None, ...]

    return _with_leading_axis(image), _with_leading_axis(mask)

def save_as_png(nii_path, output_dir, z):
    """
    Build the PNG output path for slice ``z`` of a NIfTI volume.

    Despite its name, this function does not write any image. It creates
    ``output_dir`` if it does not exist and returns the path where the slice
    should be saved, named ``<base>_slice<zzz>.png``. ``<base>`` is the file
    name of ``nii_path`` without its trailing ``.nii.gz`` / ``.nii``
    extension, and ``<zzz>`` is ``z`` zero-padded to 3 digits.

    Args:
        nii_path (str): path of the source NIfTI volume.
        output_dir (str): directory that will hold the PNG file.
        z (int): index of the slice.

    Returns:
        str: path of the PNG file inside ``output_dir``.
    """
    base_name = os.path.basename(nii_path)
    # Strip only a trailing extension; check ".nii.gz" before ".nii".
    for suffix in (".nii.gz", ".nii"):
        if base_name.endswith(suffix):
            base_name = base_name[: -len(suffix)]
            break
    slice_name = f"{base_name}_slice{z:03d}.png"
    os.makedirs(output_dir, exist_ok=True)
    return os.path.join(output_dir, slice_name)

# 2D Dataset

In [None]:
def _slice_window(center, n_slices, depth):
    """Return up to ``n_slices`` consecutive slice indices centred on ``center``.

    The window is shifted (not truncated) so it stays inside ``[0, depth)``;
    it is shortened only when ``n_slices`` exceeds ``depth``. A non-positive
    ``n_slices`` yields an empty range.
    """
    n_slices = max(0, min(n_slices, depth))
    start = center - n_slices // 2
    # Shift back inside the volume instead of producing negative (numpy
    # wrap-around) or past-the-end slice indices.
    start = max(0, min(start, depth - n_slices))
    return range(start, start + n_slices)


def load_n_slices_per_organ(annotations, organs, n_slices):
    """
    Select ``n_slices`` slices around the centre of each requested organ.

    For every annotation and every requested organ listed in it, the mask is
    scanned along its first (depth) axis for the first and last slice that
    contain the organ's label. The window of ``n_slices`` consecutive slices
    centred on their midpoint is recorded, shifted if needed to stay inside
    the volume.

    Args:
        annotations (list): dicts with ``img_path``, ``mask_path`` and
            ``organs`` (label values present in the mask), as produced by
            ``transform_to_csv``. ``organs`` may also be the string form of
            that list, as found when rows are read back from ``dataset.csv``.
        organs (list of str): organ names as defined in ``organ_mapper``
            (matched case-insensitively).
        n_slices (int): number of slices to select per organ and scan.

    Returns:
        slices_df (pandas.DataFrame): one row per selected slice, with columns
            ``img_path``, ``mask_path``, ``slice_index`` and ``organ``.

    Raises:
        ValueError: if a name in ``organs`` is not a known organ.
    """
    import ast

    _, organ_to_key = organ_mapper()
    requested = []
    for organ in organs:
        # organ_to_key is keyed by lower-cased names, so normalise the query.
        organ_key = organ_to_key.get(organ.lower())
        if organ_key is None:
            # A typo would otherwise be silently "not found" in every scan.
            raise ValueError(f"Unknown organ name: {organ!r}")
        requested.append((organ, organ_key))

    rows = []
    for annotation in annotations:
        img_path = annotation["img_path"]
        mask_path = annotation["mask_path"]
        present = annotation["organs"]
        if isinstance(present, str):
            # Rows read back from dataset.csv hold the list as its repr.
            present = ast.literal_eval(present)

        mask_arr = None  # loaded lazily, at most once per scan
        for organ, organ_key in requested:
            if organ_key not in present:
                print(f"Organ {organ} not found in annotation for {img_path}. Skipping.")
                continue
            if mask_arr is None:
                # Only the mask is needed to locate the slices.
                mask_arr = sitk.GetArrayFromImage(sitk.ReadImage(mask_path))  # Shape: (Depth, Height, Width)

            slice_indices = np.flatnonzero(np.any(mask_arr == organ_key, axis=(1, 2)))
            if slice_indices.size == 0:
                print(f"Organ {organ} listed for {img_path} but absent from its mask. Skipping.")
                continue
            middle_slice = (int(slice_indices[0]) + int(slice_indices[-1])) // 2
            for slice_index in _slice_window(middle_slice, n_slices, mask_arr.shape[0]):
                rows.append(
                    {
                        "img_path": img_path,
                        "mask_path": mask_path,
                        "slice_index": slice_index,
                        "organ": organ,
                    }
                )

    # Explicit columns keep a well-formed (empty) frame when nothing matched.
    slices_df = pd.DataFrame(rows, columns=["img_path", "mask_path", "slice_index", "organ"])
    return slices_df
