Merge pull request #3 from aristizabal95/nnunet
Use nnUNet for tumor segmentation
aristizabal95 committed Dec 18, 2023
2 parents 72e9126 + c97597d commit c6f54af
Showing 11 changed files with 250 additions and 46 deletions.
10 changes: 10 additions & 0 deletions Dockerfile
@@ -85,6 +85,16 @@ RUN cp -R /Front-End/bin/install/appdir/usr/bin/data_prep_models /project/stages
# Hotfix: install more recent version of GaNDLF for metrics generation
RUN pip install git+https://github.com/mlcommons/GaNDLF@616b37bafad8f89d5c816a88f44fa30470601311

+RUN pip install torch torchvision
+
+RUN pip install git+https://github.com/MIC-DKFZ/nnUNet.git@nnunetv1
+
+RUN mkdir /nnUNet_raw_data_base && mkdir /nnUNet_preprocessed
+
+ENV nnUNet_raw_data_base="/nnUNet_raw_data_base"
+ENV nnUNet_preprocessed="/nnUNet_preprocessed"
+ENV RESULTS_FOLDER="/project/models/nnUNet_trained_models"

COPY ./mlcubes/data_preparation/project /project

ENTRYPOINT ["python", "/project/mlcube.py"]
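
The three ENV lines matter because nnU-Net v1 resolves all of its paths from environment variables at import time, with RESULTS_FOLDER pointing at the trained models baked into the image. A minimal startup sanity check along these lines (an illustrative addition, not part of the commit) could catch a missing or empty models directory early:

import os

# nnU-Net v1 reads these variables when it is imported; RESULTS_FOLDER must
# contain the trained models copied into (or mounted onto) the image.
REQUIRED_ENV = ["nnUNet_raw_data_base", "nnUNet_preprocessed", "RESULTS_FOLDER"]

def check_nnunet_env() -> None:
    for var in REQUIRED_ENV:
        path = os.environ.get(var)
        if path is None or not os.path.isdir(path):
            raise RuntimeError(f"{var} is unset or does not point to a directory")

check_nnunet_env()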
3 changes: 2 additions & 1 deletion mlcubes/.gitignore
@@ -6,4 +6,5 @@
*.png
*/mlcube/workspace/*
!requirements.txt
-!*/mlcube/workspace/parameters.yaml
+!*/mlcube/workspace/parameters.yaml
+models
3 changes: 2 additions & 1 deletion mlcubes/data_preparation/mlcube/mlcube.yaml
@@ -8,7 +8,7 @@ platform:

docker:
  # Image name
-  image: mlcommons/rano-data-prep:latest
+  image: mlcommons/rano-data-prep:nnunet
  # Docker build context relative to $MLCUBE_ROOT. Default is `build`.
  build_context: "../project"
  # Docker file name within docker build context, default is `Dockerfile`.
@@ -21,6 +21,7 @@ tasks:
        data_path: input_data,
        labels_path: input_labels,
        parameters_file: parameters.yaml,
+        models: additional_files/models,
      }
      outputs: {
        output_path: data/,
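
The new models input follows MLCube's convention that relative input paths resolve against the workspace directory, so at runtime the trained models are expected under workspace/additional_files/models on the host. A hedged sketch of that resolution (the task name and key layout are assumed from the visible excerpt and the standard MLCube schema, not verified against a particular runner version):

import os
import yaml

# MLCUBE_ROOT is a hypothetical checkout location.
MLCUBE_ROOT = "/path/to/mlcubes/data_preparation/mlcube"

with open(os.path.join(MLCUBE_ROOT, "mlcube.yaml")) as f:
    cube = yaml.safe_load(f)

models_rel = cube["tasks"]["prepare"]["parameters"]["inputs"]["models"]
print(os.path.join(MLCUBE_ROOT, "workspace", models_rel))
# -> .../workspace/additional_files/models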
3 changes: 2 additions & 1 deletion mlcubes/data_preparation/project/mlcube.py
@@ -23,12 +23,13 @@ def prepare(
    data_path: str = typer.Option(..., "--data_path"),
    labels_path: str = typer.Option(..., "--labels_path"),
    parameters_file: str = typer.Option(..., "--parameters_file"),
+    models_path: str = typer.Option(..., "--models"),
    output_path: str = typer.Option(..., "--output_path"),
    output_labels_path: str = typer.Option(..., "--output_labels_path"),
    report_file: str = typer.Option(..., "--report_file"),
    metadata_path: str = typer.Option(..., "--metadata_path"),
):
-    cmd = f"python3 project/prepare.py --data_path={data_path} --labels_path={labels_path} --data_out={output_path} --labels_out={output_labels_path} --report={report_file} --parameters={parameters_file} --metadata_path={metadata_path}"
+    cmd = f"python3 project/prepare.py --data_path={data_path} --labels_path={labels_path} --models_path={models_path} --data_out={output_path} --labels_out={output_labels_path} --report={report_file} --parameters={parameters_file} --metadata_path={metadata_path}"
    exec_python(cmd)


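
Note that the command is assembled as a single f-string, so any path containing whitespace would be split when the command is tokenized. A defensive variant (a sketch, not the repository's exec_python helper) would pass an argument list instead:

import subprocess

# Sketch: the same prepare.py invocation as an argument list, which
# survives whitespace in paths. Parameter names mirror the typer options
# above.
def run_prepare(data_path, labels_path, parameters_file, models_path,
                output_path, output_labels_path, report_file, metadata_path):
    subprocess.run([
        "python3", "project/prepare.py",
        f"--data_path={data_path}",
        f"--labels_path={labels_path}",
        f"--models_path={models_path}",
        f"--data_out={output_path}",
        f"--labels_out={output_labels_path}",
        f"--report={report_file}",
        f"--parameters={parameters_file}",
        f"--metadata_path={metadata_path}",
    ], check=True)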
34 changes: 12 additions & 22 deletions mlcubes/data_preparation/project/prepare.py
@@ -2,17 +2,21 @@
import argparse
import pandas as pd
import yaml
+import shutil
from stages.generate_report import GenerateReport
from stages.get_csv import AddToCSV
from stages.nifti_transform import NIfTITransform
from stages.extract import Extract
+from stages.extract_nnunet import ExtractNnUNet
from stages.manual import ManualStage
from stages.comparison import SegmentationComparisonStage
from stages.confirm import ConfirmStage
from stages.split import SplitStage
from stages.pipeline import Pipeline
from stages.constants import INTERIM_FOLDER, FINAL_FOLDER, TUMOR_MASK_FOLDER

+MODELS_PATH = "/project/models"


def find_csv_filenames(path_to_dir, suffix=".csv"):
filenames = os.listdir(path_to_dir)
@@ -27,6 +31,9 @@ def setup_argparser():
    parser.add_argument(
        "--labels_path", dest="labels", type=str, help="path containing labels"
    )
+    parser.add_argument(
+        "--models_path", dest="models", type=str, help="path to the nnunet models"
+    )
    parser.add_argument(
        "--data_out", dest="data_out", type=str, help="path to store prepared data"
    )
@@ -78,7 +85,7 @@ def init_pipeline(args):
    loop = None
    report_gen = GenerateReport(out_data_csv, args.data, out_raw, args.labels, args.labels_out, args.data_out, 8, brain_data_out, 3, tumor_data_out, 5)
    csv_proc = AddToCSV(out_raw, out_data_csv, valid_data_out, out_raw)
-    nifti_proc = NIfTITransform(out_data_csv, nifti_data_out, valid_data_out, args.metadata_path)
+    nifti_proc = NIfTITransform(out_data_csv, nifti_data_out, valid_data_out, args.metadata_path, args.data_out)
    brain_extract_proc = Extract(
        out_data_csv,
        brain_data_out,
@@ -89,14 +96,12 @@
"extract_brain",
3,
)
tumor_extract_proc = Extract(
tumor_extract_proc = ExtractNnUNet(
out_data_csv,
tumor_data_out,
INTERIM_FOLDER,
brain_data_out,
INTERIM_FOLDER,
# loop,
"extract_tumor",
4,
)
manual_proc = ManualStage(out_data_csv, tumor_data_out, tumor_data_out, backup_out)
@@ -142,24 +147,9 @@ def init_report(args) -> pd.DataFrame:
def main():
    args = setup_argparser()

-    # Check if the input data is already prepared
-    # If so, just copy the contents and skip all processing
-    # TODO: this means we won't have a report. What would be the best way
-    # to handle this?
-    # TODO: Re-enable this when it is implemented correctly and we see the need for it
-    # # 1. If there is a csv file in the input folder
-    # # always reuse it for the prepared dataset
-    # csvs = find_csv_filenames(args.data_out)
-    # if len(csvs) == 1:
-    #     # One csv was found. Assume this is the desired csv
-    #     # move it to the expected location
-    #     # TODO: How to deal with inconsistent paths because of MLCube functionality?
-    #     csv_path = os.path.join(args.data_out, csvs[0])
-    #     os.rename(csv_path, out_data_csv)
-    #     # can we assume the paths inside data.csv to be relative to the csv?
-    #     # TODO: Create some logic to turn the csv paths into the expected paths for the MLCube
-    #     # update_csv_paths(out_data_csv)

+    # Move models to the expected location
+    if not os.path.exists(MODELS_PATH):
+        shutil.copytree(args.models, MODELS_PATH)

    report = init_report(args)
    pipeline = init_pipeline(args)
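
One subtlety in the new model-staging logic: the copy is skipped whenever MODELS_PATH already exists, so a stale or partially populated /project/models silently wins over freshly mounted models. A per-model variant (illustrative only; the commit's behavior is the all-or-nothing check above) would be:

import os
import shutil

# Sketch: stage each model directory individually so an existing but
# incomplete /project/models does not shadow newly mounted models.
def stage_models(src: str, dst: str = "/project/models") -> None:
    os.makedirs(dst, exist_ok=True)
    for entry in os.listdir(src):
        target = os.path.join(dst, entry)
        if not os.path.exists(target):
            shutil.copytree(os.path.join(src, entry), target)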
4 changes: 2 additions & 2 deletions mlcubes/data_preparation/project/stages/extract.py
@@ -74,7 +74,7 @@ def execute(
"""
self.__prepare_exec()
self.__copy_case(index)
self.__process_case(index)
self._process_case(index)
report, success = self.__update_state(index, report)
self.prep.write()

@@ -99,7 +99,7 @@ def __copy_case(self, index: Union[str, int]):
        for prev, copy in zip(prev_paths, copy_paths):
            shutil.copytree(prev, copy, dirs_exist_ok=True)

-    def __process_case(self, index: Union[str, int]):
+    def _process_case(self, index: Union[str, int]):
        id, tp = get_id_tp(index)
        df = self.prep.subjects_df
        row_search = df[(df["SubjectID"] == id) & (df["Timepoint"] == tp)]
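
The rename from __process_case to _process_case is what makes the nnUNet subclass possible: names with two leading underscores are mangled per class, so an override in ExtractNnUNet would never be reached through the parent's execute. A minimal illustration:

# With double underscores, Python rewrites the call site per class, so the
# subclass "override" is invisible to the parent:
class Base:
    def run(self):
        self.__step()        # resolves to Base._Base__step

    def __step(self):
        print("base step")

class Child(Base):
    def __step(self):        # mangled to Child._Child__step, never called
        print("child step")

Child().run()                # prints "base step", not "child step"

With a single leading underscore there is no mangling, so self._process_case(index) dispatches to the subclass's implementation.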
193 changes: 193 additions & 0 deletions mlcubes/data_preparation/project/stages/extract_nnunet.py
@@ -0,0 +1,193 @@
from typing import Union, List, Tuple
from tqdm import tqdm
import pandas as pd
import os
from os.path import realpath, dirname, join
import shutil
import time
import SimpleITK as sitk
import subprocess
import traceback
from LabelFusion.wrapper import fuse_images

from .extract import Extract
from .PrepareDataset import (
    Preparator,
    FINAL_FOLDER,
    generate_tumor_segmentation_fused_images,
    save_screenshot,
)
from .utils import update_row_with_dict, get_id_tp, MockTqdm

MODALITY_MAPPING = {
    "t1c": "t1c",
    "t1ce": "t1c",
    "t1": "t1n",
    "t1n": "t1n",
    "t2": "t2w",
    "t2w": "t2w",
    "t2f": "t2f",
    "flair": "t2f",
}

MODALITY_VARIANTS = {
    "t1c": "T1GD",
    "t1ce": "T1GD",
    "t1": "T1",
    "t1n": "T1",
    "t2": "T2",
    "t2w": "T2",
    "t2f": "FLAIR",
    "flair": "FLAIR",
}
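
# Worked example (hypothetical filename): for "sub-01_2008_t1ce.nii.gz" the
# modality token is "t1ce"; MODALITY_MAPPING normalizes it to "t1c" for
# nnU-Net-style file naming, while MODALITY_VARIANTS labels it "T1GD" when
# collecting inputs for the QC screenshots.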


class ExtractNnUNet(Extract):
    def __init__(
        self,
        data_csv: str,
        out_path: str,
        subpath: str,
        prev_stage_path: str,
        prev_subpath: str,
        status_code: int,
        extra_labels_path=[],
    ):
        self.data_csv = data_csv
        self.out_path = out_path
        self.subpath = subpath
        self.data_subpath = FINAL_FOLDER
        self.prev_path = prev_stage_path
        self.prev_subpath = prev_subpath
        os.makedirs(self.out_path, exist_ok=True)
        self.prep = Preparator(data_csv, out_path, "BraTSPipeline")
        self.pbar = tqdm()
        self.failed = False
        self.exception = None
        self.__status_code = status_code
        self.extra_labels_path = extra_labels_path

    @property
    def name(self) -> str:
        return "nnUNet Tumor Extraction"

    @property
    def status_code(self) -> int:
        return self.__status_code

    def __get_models(self):
        rel_models_path = "../models/nnUNet_trained_models/nnUNet/3d_fullres"
        models_path = realpath(join(dirname(__file__), rel_models_path))
        return os.listdir(models_path)

    def __get_mod_order(self, model):
        rel_orders_path = "../models/nnUNet_modality_order"
        order_path = realpath(join(dirname(__file__), rel_orders_path, model, "order"))
        with open(order_path, "r") as f:
            order_str = f.readline()
        # drop the leading "order" and "=" tokens from the split list
        modalities = order_str.split()[2:]
        modalities = [MODALITY_MAPPING[mod] for mod in modalities]
        return modalities
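
    # Worked example (assumed order-file contents): a file reading
    #   order = t1 t2 t1ce flair
    # splits into ["order", "=", "t1", "t2", "t1ce", "flair"]; dropping the
    # first two tokens and mapping gives ["t1n", "t2w", "t1c", "t2f"].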

    def __prepare_case(self, path, id, tp, order):
        tmp_subject = f"{id}-{tp}"
        tmp_path = os.path.join(path, "tmp-data")
        tmp_subject_path = os.path.join(tmp_path, tmp_subject)
        tmp_out_path = os.path.join(path, "tmp-out")
        shutil.rmtree(tmp_path, ignore_errors=True)
        shutil.rmtree(tmp_out_path, ignore_errors=True)
        os.makedirs(tmp_subject_path)
        os.makedirs(tmp_out_path)
        in_modalities_path = os.path.join(path, "DataForFeTS", id, tp)
        input_modalities = {}
        for modality_file in os.listdir(in_modalities_path):
            if not modality_file.endswith(".nii.gz"):
                continue
            modality = modality_file[:-7].split("_")[-1]
            norm_mod = MODALITY_MAPPING[modality]
            mod_idx = order.index(norm_mod)
            mod_idx = str(mod_idx).zfill(4)

            out_modality_file = f"{tmp_subject}_{mod_idx}.nii.gz"
            in_file = os.path.join(in_modalities_path, modality_file)
            out_file = os.path.join(tmp_subject_path, out_modality_file)
            input_modalities[MODALITY_VARIANTS[modality]] = in_file
            shutil.copyfile(in_file, out_file)

        return tmp_subject_path, tmp_out_path, input_modalities
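
    # Worked example (hypothetical subject "sub-01", timepoint "2008", order
    # ["t1n", "t2w", "t1c", "t2f"]): the t1c scan is copied to
    #   tmp-data/sub-01-2008/sub-01-2008_0002.nii.gz
    # since "t1c" sits at index 2, matching nnU-Net's convention that each
    # case be named <case>_<4-digit modality index>.nii.gz.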

    def __run_model(self, model, data_path, out_path):
        # models are named Task<ID>_..., where <ID> is always 3 numbers
        task_id = model[4:7]
        cmd = f"nnUNet_predict -i {data_path} -o {out_path} -t {task_id} -f all"
        print(cmd)
        print(os.listdir(data_path))
        start = time.time()
        subprocess.call(cmd, shell=True)
        end = time.time()
        total_time = end - start
        print(f"Total time elapsed is {total_time} seconds")

    def __finalize_pred(self, tmp_out_path, out_pred_filepath):
        # We assume there's only one file in out_path
        pred = None
        for file in os.listdir(tmp_out_path):
            if file.endswith(".nii.gz"):
                pred = file

        if pred is None:
            raise RuntimeError("No tumor segmentation was found")

        pred_filepath = os.path.join(tmp_out_path, pred)
        shutil.move(pred_filepath, out_pred_filepath)
        return out_pred_filepath

    def _process_case(self, index: Union[str, int]):
        id, tp = get_id_tp(index)
        subject_id = f"{id}_{tp}"
        models = self.__get_models()
        outputs = []
        images_for_fusion = []
        out_path = os.path.join(self.out_path, "DataForQC", id, tp)
        out_pred_path = os.path.join(out_path, "TumorMasksForQC")
        os.makedirs(out_pred_path, exist_ok=True)
        for i, model in enumerate(models):
            order = self.__get_mod_order(model)
            tmp_data_path, tmp_out_path, input_modalities = self.__prepare_case(
                self.out_path, id, tp, order
            )
            out_pred_filepath = os.path.join(
                out_pred_path, f"{id}_{tp}_tumorMask_model_{i}.nii.gz"
            )
            try:
                self.__run_model(model, tmp_data_path, tmp_out_path)
                output = self.__finalize_pred(tmp_out_path, out_pred_filepath)
                outputs.append(output)
                images_for_fusion.append(sitk.ReadImage(output, sitk.sitkUInt8))
            except Exception as e:
                self.exception = e
                self.failed = True
                self.traceback = traceback.format_exc()
                return

            # cleanup
            shutil.rmtree(tmp_data_path, ignore_errors=True)
            shutil.rmtree(tmp_out_path, ignore_errors=True)

        fused_outputs = generate_tumor_segmentation_fused_images(
            images_for_fusion, out_pred_path, subject_id
        )
        outputs += fused_outputs

        for output in outputs:
            # save the screenshot
            tumor_mask_id = os.path.basename(output).replace(".nii.gz", "")
            save_screenshot(
                input_modalities,
                os.path.join(
                    out_path,
                    f"{tumor_mask_id}_summary.png",
                ),
                output,
            )
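
Taken together, _process_case runs every model found under 3d_fullres on the case, writes one tumor mask per model, fuses the masks with LabelFusion, and saves a QC screenshot per mask. A hedged usage sketch of the stage on its own (concrete paths are hypothetical; the real wiring is the init_pipeline call in prepare.py above):

from stages.extract_nnunet import ExtractNnUNet
from stages.constants import INTERIM_FOLDER

# Arguments mirror the ExtractNnUNet constructor; paths are hypothetical.
stage = ExtractNnUNet(
    "/data/out/data.csv",   # data_csv
    "/data/out/tumor",      # out_path
    INTERIM_FOLDER,         # subpath
    "/data/out/brain",      # prev_stage_path
    INTERIM_FOLDER,         # prev_subpath
    4,                      # status_code, as wired in init_pipeline
)
# stage.execute(...) is inherited from Extract and dispatches to the
# nnUNet-specific _process_case for each case index.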
18 changes: 9 additions & 9 deletions mlcubes/data_preparation/project/stages/generate_report.py
@@ -345,6 +345,15 @@ def execute(self, report: pd.DataFrame) -> Tuple[pd.DataFrame, bool]:
            # Keep track of the cases that were found on the input folder
            observed_cases.add(index)

+            has_semiprepared, in_tp_path = has_semiprepared_folder_structure(in_tp_path, recursive=True)
+            if has_semiprepared:
+                tumor_seg = get_tumor_segmentation(subject, timepoint, in_tp_path)
+                if tumor_seg is not None:
+                    report = self._proceed_to_comparison(subject, timepoint, in_tp_path, report)
+                else:
+                    report = self._proceed_to_tumor_extraction(subject, timepoint, in_tp_path, report)
+                continue

            if index in report.index:
                # Case has already been identified, see if input hash is different
                # if so, override the contents and restart the state for that case
@@ -373,15 +382,6 @@
            # Move files around so it has the expected structure
            to_expected_folder_structure(out_tp_path, contents_path)

-            has_semiprepared, in_tp_path = has_semiprepared_folder_structure(in_tp_path, recursive=True)
-            if has_semiprepared:
-                tumor_seg = get_tumor_segmentation(subject, timepoint, in_tp_path)
-                if tumor_seg is not None:
-                    report = self._proceed_to_comparison(subject, timepoint, in_tp_path, report)
-                else:
-                    report = self._proceed_to_tumor_extraction(subject, timepoint, in_tp_path, report)
-                continue

            if input_is_prepared:
                data["status_name"] = "DONE"
                data["status_code"] = self.done_status_code
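
This hunk moves the block rather than adding it: semi-prepared inputs are now routed as soon as a case is observed, before the report-hash comparison, instead of after the folder restructuring. A condensed restatement of the routing (same behavior as the diff above, written out for clarity):

# Condensed restatement of the routing moved above (names as in the diff):
def route_semiprepared(self, subject, timepoint, in_tp_path, report):
    # A case that already carries a tumor segmentation goes straight to
    # manual comparison; otherwise it re-enters at tumor extraction.
    if get_tumor_segmentation(subject, timepoint, in_tp_path) is not None:
        return self._proceed_to_comparison(subject, timepoint, in_tp_path, report)
    return self._proceed_to_tumor_extraction(subject, timepoint, in_tp_path, report)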
