In [1]:
import os
import nrrd
import numpy as np
import tifffile as tif

from octvision3d.utils import get_filenames, create_dataset_dirs, save_json, generate_dataset_json

In [2]:
# Folder that is searched for the input files.
# Previously used location, kept for reference:
#   /data/dkermany_data/3D-OCT/first-batch-labeled
dir_path = "/Users/danielkermany/Downloads/ERM-3/"

# The converted dataset is written to a sub-folder of the input folder.
output_path = os.path.join(dir_path, "nnUNet_Dataset")

In [4]:
# create_dataset_dirs is a project helper; presumably it creates the folder
# layout under output_path -- TODO confirm against octvision3d.utils.
create_dataset_dirs(output_path)

# Destinations for the converted volumes and for their label maps.
imagesTr, labelsTr = (
    os.path.join(output_path, subdir) for subdir in ("imagesTr", "labelsTr")
)

# Class abbreviations in label order: the first entry gets label 1, the
# second label 2, and so on.
label_names = [
    "CNV", "DRU", "EX", "FLU", "GA",
    "HEM", "RPE", "RET", "CHO", "VIT",
    "HYA", "SHS", "ART", "ERM", "SES",
]

# Mapping of class abbreviation -> integer label (1..15).
# NOTE(review): there is no explicit entry for label 0 / background here --
# confirm whether generate_dataset_json adds one if the consumer requires it.
labels_dict = {name: index for index, name in enumerate(label_names, start=1)}

In [5]:
# Convert each TIFF volume / seg.nrrd pair found under dir_path into the files
# written below: imagesTr/<case>_0000.tif, labelsTr/<case>.tif and a
# <case>.json spacing file next to each, then write the dataset description.

# Spacing payload shared by the image and label JSON files, defined once so
# the two copies cannot drift apart.
# NOTE(review): the units and axis order of these values are not stated in
# this notebook -- confirm they match the axis order of the saved arrays.
spacing_json = {"spacing": [81.0, 1.0, 2.9]}

# output_path is a sub-folder of dir_path. The trailing separator keeps a
# sibling such as "nnUNet_Dataset_old" from being mistaken for it.
output_root = os.path.join(os.path.abspath(output_path), "")


def _is_source_file(path):
    """Return True unless the path contains "slo" or lies inside output_path.

    The second condition keeps a re-run from treating previously written
    outputs as new inputs, in case get_filenames descends into sub-folders.
    """
    return "slo" not in path and not os.path.abspath(path).startswith(output_root)


def _case_key(path, n_ext):
    """Return (directory, file name with its last ``n_ext`` extensions removed).

    Only the file name is trimmed, so dots in directory names do not matter:
    ("a/b.c/x.tif", 1) and ("a/b.c/x.seg.nrrd", 2) both give ("a/b.c", "x").
    """
    folder, name = os.path.split(path)
    for _ in range(n_ext):
        name = os.path.splitext(name)[0]
    return folder, name


vol_paths = [p for p in get_filenames(dir_path, ext="tif") if _is_source_file(p)]
seg_paths = [p for p in get_filenames(dir_path, ext="seg.nrrd") if _is_source_file(p)]

# Pair volumes and segmentations by key instead of by list position, so the
# result does not depend on the order in which get_filenames returns paths.
vol_keys = [_case_key(p, 1) for p in vol_paths]
seg_by_key = {_case_key(p, 2): p for p in seg_paths}

# Fail loudly instead of silently dropping cases: every volume needs exactly
# one segmentation and vice versa.
unmatched = set(vol_keys) ^ set(seg_by_key)
assert not unmatched, f"volumes and segmentations do not pair up: {sorted(unmatched)}"

# Outputs are named by file stem only, so equal stems from different folders
# would overwrite each other.
case_ids = [name for _, name in vol_keys]
assert len(set(case_ids)) == len(case_ids), "duplicate case names would overwrite each other"

for vol_path, key in zip(vol_paths, vol_keys):
    seg_path = seg_by_key[key]
    # One identifier for both outputs so the image and its label always share
    # the same case name, even when the file name contains extra dots.
    case_id = key[1]

    # Load TIFF volume and seg.nrrd labels
    vol = tif.imread(vol_path)
    bitmap, header = nrrd.read(seg_path)

    # One-hot bitmap to label indices: argmax over the first axis picks the
    # channel index, and .T reverses the remaining axes
    # ((X, Y, Z) -> (Z, Y, X) according to the original note).
    # NOTE(review): voxels where no channel is set and voxels set in channel 0
    # both become label 0 -- confirm channel 0 is background.
    labels = np.argmax(bitmap, axis=0).T

    # Save spacing json next to the image and next to the label
    save_json(spacing_json, os.path.join(imagesTr, f"{case_id}.json"))
    save_json(spacing_json, os.path.join(labelsTr, f"{case_id}.json"))

    # Save vol and label tiff images
    output_tif = os.path.join(imagesTr, f"{case_id}_0000.tif")
    output_labels = os.path.join(labelsTr, f"{case_id}.tif")
    tif.imwrite(output_tif, vol, photometric='minisblack')
    tif.imwrite(output_labels, labels, photometric='minisblack')

generate_dataset_json(output_path,
                      channel_names={"0": "OCT"},
                      labels=labels_dict,
                      file_ending=".tif",
                      num_training_cases=len(vol_paths),
                      dataset_name="3D OCT Dataset")