In [1]:
# Pin this kernel to a single GPU device, identified by its MIG-prefixed UUID.
# This is done in the first cell, before torch is imported in the next one.
%env CUDA_VISIBLE_DEVICES=MIG-3b133eb8-ed94-51f8-937f-cbc3e3f3ff2a
# Echo the variable back to confirm the value that was set.
%env CUDA_VISIBLE_DEVICES

env: CUDA_VISIBLE_DEVICES=MIG-3b133eb8-ed94-51f8-937f-cbc3e3f3ff2a


'MIG-3b133eb8-ed94-51f8-937f-cbc3e3f3ff2a'

In [2]:
import argparse
import json
from pathlib import Path

import mlflow
import numpy as np
import torch
from monai.losses import DiceCELoss
from tqdm.notebook import tqdm
import SimpleITK as sitk

import cfg
from dataset_mevis_v2 import MRI_dataset_batched
from dsc import dice_coeff
from funcs import calculate_sensitivity_specificity
from utils import iou_torch

  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)
  return register_model(fn_wrapper)


In [3]:
# Tracking-server address for the MLflow client; presumably this is what
# resolves the `models:/...` URI loaded further down.
# NOTE(review): 0.0.0.0 is normally a listen address; confirm it is reachable
# as a client target on this host (127.0.0.1 may be what is meant).
%env MLFLOW_TRACKING_URI=http://0.0.0.0:5000

env: MLFLOW_TRACKING_URI=http://0.0.0.0:5000


In [None]:
def _str_to_bool(value):
    """Convert a command-line token such as "true", "0" or "no" into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def mevis_args_parser(argv=None):
    """Build the Mevis SAM fine-tuning options and return them as a namespace.

    Args:
        argv: optional list of command-line tokens to parse. When left as
            ``None`` an empty argument list is parsed, so every option takes
            its default and the process's own ``sys.argv`` is never read.

    Returns:
        argparse.Namespace with ``prompt_probability``, ``checkpoint_path``,
        ``lr_schedule``, ``lr_train_start``, ``lr_train_end``, ``lr_warmup``,
        ``epochs``, ``epochs_warmup`` and ``batch_size``.
    """
    # changing prompt probability, do not import
    parser = argparse.ArgumentParser(description="Mevis SAM fine-tuning parameters.")
    parser.add_argument(
        "--prompt_probability",
        type=float,
        default=0.3,
        help="probability of generating prompts for each batch",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="",
        help="Checkpoint to continue the training",
    )
    parser.add_argument(
        "--lr_schedule",
        # Not ``type=bool``: bool("False") is True, which would make the
        # option impossible to switch off from the command line.
        type=_str_to_bool,
        default=True,
        help="Use learning rate scheduler during training.",
    )
    parser.add_argument(
        "--lr_train_start",
        type=float,
        default=5e-4,
        help="Learning rate starting value during training.",
    )
    parser.add_argument(
        "--lr_train_end",
        type=float,
        default=5e-5,
        help="Learning rate ending value during training.",
    )
    parser.add_argument(
        "--lr_warmup", type=float, default=1e-5, help="Learning rate on warmup."
    )
    parser.add_argument(
        "--epochs", type=int, default=100, help="Number of training epochs"
    )
    parser.add_argument(
        "--epochs_warmup", type=int, default=20, help="Number of warmup epochs"
    )
    parser.add_argument(
        "--batch_size", type=int, default=120, help="Number of slices in a batch"
    )
    # An explicit (possibly empty) list keeps argparse away from sys.argv.
    mevis_args = parser.parse_args([] if argv is None else argv)

    return mevis_args


# Notebook-local options; called without arguments, so every option is a default.
mevis_args = mevis_args_parser()
# Project-wide options supplied by the cfg module.
args = cfg.parse_args()
# Combine both sets into one namespace. The double ** expansion raises
# TypeError if the two namespaces share an option name.
merged_args = argparse.Namespace(**vars(args), **vars(mevis_args))
# Override the prompt probability on the merged namespace only;
# `mevis_args` itself keeps the value returned by the parser.
merged_args.prompt_probability = 1.0

In [5]:
# Per-split JSON files handed to the dataset class as `data_file`.
TRAIN_DATA_FILE = "data_files/Train_data_files_resampled_v2.json"
VALID_DATA_FILE = "data_files/Validation_data_files_resampled_v2.json"
TEST_DATA_FILE = "data_files/Test_data_files_resampled_v2.json"

# Name and version used to build the `models:/<name>/<version>` URI.
MODEL_NAME = "mevis_sam_v2_epoch151"
MODEL_VERSION = "1"

# CUDA device selected through the project config's `gpu_device` option.
DEVICE = torch.device(f"cuda:{args.gpu_device}")

In [None]:
# URI of the form "models:/<name>/<version>", built from the constants above.
model_uri = f"models:/{MODEL_NAME}/{MODEL_VERSION}"
# Load the model through MLflow's PyTorch flavour.
model = mlflow.pytorch.load_model(model_uri)
# Put the model in evaluation mode. As the last expression of the cell, the
# returned module is also displayed as the cell output.
model.eval()

In [7]:
def _build_eval_dataset(data_file):
    """Create an ``MRI_dataset_batched`` for one split with the shared evaluation settings.

    All three splits (test, validation, train) are built with identical
    options, including ``phase="test"``; only the ``data_file`` differs.

    Args:
        data_file: path of the JSON file describing the split.

    Returns:
        The constructed ``MRI_dataset_batched`` instance.
    """
    return MRI_dataset_batched(
        merged_args,
        data_file=data_file,
        batch_size=mevis_args.batch_size,
        phase="test",
        operation_mode="queue",
        mask_out_size=merged_args.out_size,
        attention_size=64,
        crop=False,
        crop_size=1024,
        cls=1,
        if_prompt=True,
        prompt_type="points",
        if_attention_map=True,
        device=DEVICE,
    )


# Same construction order as before: test, validation, train.
dataset_test = _build_eval_dataset(TEST_DATA_FILE)
dataset_validation = _build_eval_dataset(VALID_DATA_FILE)
dataset_train = _build_eval_dataset(TRAIN_DATA_FILE)

In [9]:
# Fetch the first queued test batch to inspect what a batch contains.
batch = dataset_test[0]
# Names of the source volumes that contribute slices to this batch.
print(batch["image_name"])
# Slice index of every element in the batch; the indices restart where the
# next volume begins (see the printed output).
print(batch["slices"])


['7_t1.nii.gz', '7_t2_SPACE.nii.gz']
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102]


In [60]:
# Recompute the prediction for `batch` in this cell so it no longer depends on
# a `low_res_masks_bool` variable left in the kernel by an earlier interactive
# run (the name is not defined anywhere at notebook level).
with torch.no_grad():
    low_res_masks = model.forward(batch, multimask_output=True, if_attention=True)
# NOTE(review): assumes the model's own `mask_threshold` is the intended
# cut-off; falls back to 0.0 when the attribute is absent. Confirm.
low_res_masks_bool = low_res_masks > getattr(model, "mask_threshold", 0.0)

# Dice between channel 1 of the prediction and channel 0 of the reference masks.
dice_score_batch = dice_coeff(
    low_res_masks_bool[:, 1, :, :].to(torch.float32),
    batch["masks"][:, 0, :, :].to(torch.float32),
).item()

In [8]:
def run_make_roc(dataset, net, save_dir: Path, thresholds=None):
    """Run inference on every batch and collect per-image metrics at each threshold.

    Nothing is written to disk; ``save_dir`` is only used to build the
    ``prediction_path`` strings that identify which slices an entry covers.

    Args:
        dataset: sized, indexable collection of batches. Each batch is a dict
            read through the keys ``"images"``, ``"masks"``, ``"slices"``,
            ``"image_name"`` and, optionally, ``"cat_indexes"`` (start offset
            of each image's run of slices inside the batch; defaults to ``[0]``).
        net: model invoked as ``net.forward(data, multimask_output=True,
            if_attention=True)``.
        save_dir: directory used as the prefix of the reported paths.
        thresholds: sequence of cut-offs applied to the raw model output with
            ``>``. ``None`` (the default) means ``[0]``.

    Returns:
        ``{threshold: {image_name: {"prediction_path": [...], "sensitivity":
        [...], "specificity": [...], "iou": [...], "dice_score": [...],
        "dice_ce": [...]}}}`` with one list element per (image, batch) chunk.
    """
    if thresholds is None:
        thresholds = [0]
    result_dict = {thresh: {} for thresh in thresholds}
    # The loss module holds no state, so one instance serves every chunk.
    dice_ce_loss = DiceCELoss(sigmoid=True, squared_pred=True, reduction="mean")

    with tqdm(
        total=len(dataset) * len(thresholds),
        desc="Running inference on queued batches of size",
        unit="Batch",
    ) as pbar:
        for ind in range(len(dataset)):
            data = dataset[ind]
            batch_size = data["images"].shape[0]
            pbar.set_description(
                f"Running inference on queued batches of size {batch_size}"
            )
            pbar.refresh()

            with torch.no_grad():
                low_res_masks = net.forward(
                    data, multimask_output=True, if_attention=True
                )
            cat_indexes = data.get("cat_indexes", [0])

            for thresh in thresholds:
                low_res_masks_bool = low_res_masks > thresh
                per_image = result_dict[thresh]

                for j, start in enumerate(cat_indexes):
                    # The last image in the batch runs to the end of the batch.
                    stop = cat_indexes[j + 1] if j + 1 < len(cat_indexes) else None
                    chunk = slice(start, stop)

                    img_name = data["image_name"][j]
                    slices = data["slices"][chunk]
                    sliced_mask = data["masks"][chunk]
                    sliced_prediction = low_res_masks_bool[chunk]

                    # Channel 1 of the prediction is scored against channel 0
                    # of the reference mask.
                    dice_score_img = dice_coeff(
                        sliced_prediction[:, 1, :, :].to(torch.float32),
                        sliced_mask[:, 0, :, :].to(torch.float32),
                    ).item()
                    # NOTE(review): the loss receives the thresholded mask cast
                    # to float (not the raw output) while sigmoid=True is set;
                    # confirm this is the intended quantity.
                    with torch.no_grad():
                        dice_ce = dice_ce_loss(
                            sliced_prediction[:, 1:2, :, :].to(torch.float32),
                            sliced_mask.to(torch.float32),
                        ).item()
                    sensitivity, specificity = calculate_sensitivity_specificity(
                        sliced_prediction[:, 1:2, :, :], sliced_mask, class_index=1
                    )
                    iou_batch = iou_torch(
                        sliced_prediction[:, 1:2, :, :], sliced_mask.to(torch.int)
                    )

                    # Not really saved; the name records which slices the
                    # entry covers. The stem is computed outside the f-string
                    # because reusing the same quote character inside an
                    # f-string is a syntax error before Python 3.12.
                    img_stem = img_name.split(".")[0]
                    file_path = save_dir / f"{img_stem}-{slices[0]}-{slices[-1]}.nrrd"

                    # Look the image up under the current threshold. Testing
                    # the top-level dict (keyed by threshold) never matches, so
                    # an image spread over several batches would have its
                    # earlier chunks overwritten instead of appended to.
                    entry = per_image.setdefault(
                        img_name,
                        {
                            "prediction_path": [],
                            "sensitivity": [],
                            "specificity": [],
                            "iou": [],
                            "dice_score": [],
                            "dice_ce": [],
                        },
                    )
                    entry["prediction_path"].append(str(file_path))
                    entry["sensitivity"].append(sensitivity)
                    entry["specificity"].append(specificity)
                    entry["iou"].append(iou_batch)
                    entry["dice_score"].append(dice_score_img)
                    entry["dice_ce"].append(dice_ce)

                # One progress tick per (batch, threshold) pair.
                pbar.update(1)

    return result_dict

In [9]:
"""ROC calculations"""

# Output locations for the threshold sweep.
model_identifier = "mevis_sam-epoch151-29334c5d6e2b41bf86e613599ffeaff3"
eval_results_base = Path("./eval_results")
predicted_msk_folder = Path("/data/sab_data/predicted_masks/v2")
eval_results = eval_results_base / model_identifier / "roc"
eval_results.mkdir(parents=True, exist_ok=True)

save_dir = predicted_msk_folder / model_identifier
save_dir.mkdir(parents=True, exist_ok=True)
# Cut-offs -5.0, -4.0, ..., 1.0 applied to the raw model output.
thresholds = np.arange(-5, 1.2, 1).tolist()

# Sweep each split in the original order (test, validation, train) and write
# its results as soon as the split finishes.
roc_inference = {}
for split_name, split_dataset in (
    ("Test", dataset_test),
    ("Validation", dataset_validation),
    ("Train", dataset_train),
):
    roc_inference[split_name] = run_make_roc(
        split_dataset, model, save_dir=save_dir, thresholds=thresholds
    )
    results_file = (
        eval_results
        / f"{split_name}_inference_results_roc_{merged_args.prompt_probability}.json"
    )
    with open(results_file, "w") as f:
        json.dump(
            roc_inference[split_name],
            f,
            indent=4,
            sort_keys=False,
            separators=(",", ": "),
        )

# Keep the per-split names that the rest of the notebook uses.
test_inference = roc_inference["Test"]
validation_inference = roc_inference["Validation"]
train_inference = roc_inference["Train"]

Running inference on queued batches of size:   0%|          | 0/84 [00:00<?, ?Batch/s]

Running inference on queued batches of size:   0%|          | 0/84 [00:00<?, ?Batch/s]

Running inference on queued batches of size:   0%|          | 0/658 [00:00<?, ?Batch/s]

In [8]:
def run_save_inference(dataset, net, best_thresh: float, save_dir: Path, save=True):
    """Run inference at one threshold, collect per-image metrics, optionally save masks.

    Args:
        dataset: sized, indexable collection of batches. Each batch is a dict
            read through the keys ``"images"``, ``"masks"``, ``"slices"``,
            ``"image_name"``, ``"original_size"`` (only when saving) and,
            optionally, ``"cat_indexes"`` (start offset of each image's run of
            slices inside the batch; defaults to ``[0]``).
        net: model providing ``forward(data, multimask_output=True,
            if_attention=True)`` and ``postprocess_masks(masks=, input_size=,
            original_size=)``.
        best_thresh: cut-off applied to the raw model output with ``>``.
        save_dir: directory the ``.nrrd`` files are written to; also the
            prefix of the reported ``prediction_path`` strings.
        save: when True, channel 1 of the post-processed prediction for every
            (image, batch) chunk is written with SimpleITK.

    Returns:
        ``{image_name: {"prediction_path": [...], "sensitivity": [...],
        "specificity": [...], "iou": [...], "dice_score": [...],
        "dice_ce": [...]}}`` with one list element per (image, batch) chunk.
    """
    result_dict = {}
    # The loss module holds no state, so one instance serves every chunk.
    dice_ce_loss = DiceCELoss(sigmoid=True, squared_pred=True, reduction="mean")

    with tqdm(
        total=len(dataset),
        desc="Running inference on queued batches of size",
        unit="Batch",
    ) as pbar:
        for ind in range(len(dataset)):
            data = dataset[ind]
            batch_size = data["images"].shape[0]
            pbar.set_description(
                f"Running inference on queued batches of size {batch_size}"
            )
            pbar.refresh()

            with torch.no_grad():
                low_res_masks = net.forward(
                    data, multimask_output=True, if_attention=True
                )
            low_res_masks_bool = low_res_masks > best_thresh
            cat_indexes = data.get("cat_indexes", [0])

            for j, start in enumerate(cat_indexes):
                # The last image in the batch runs to the end of the batch.
                stop = cat_indexes[j + 1] if j + 1 < len(cat_indexes) else None
                chunk = slice(start, stop)

                img_name = data["image_name"][j]
                slices = data["slices"][chunk]
                sliced_mask = data["masks"][chunk]
                sliced_prediction = low_res_masks_bool[chunk]

                # Channel 1 of the prediction is scored against channel 0 of
                # the reference mask.
                dice_score_img = dice_coeff(
                    sliced_prediction[:, 1, :, :].to(torch.float32),
                    sliced_mask[:, 0, :, :].to(torch.float32),
                ).item()
                # NOTE(review): the loss receives the thresholded mask cast to
                # float (not the raw output) while sigmoid=True is set; confirm
                # this is the intended quantity.
                with torch.no_grad():
                    dice_ce = dice_ce_loss(
                        sliced_prediction[:, 1:2, :, :].to(torch.float32),
                        sliced_mask.to(torch.float32),
                    ).item()
                sensitivity, specificity = calculate_sensitivity_specificity(
                    sliced_prediction[:, 1:2, :, :], sliced_mask, class_index=1
                )
                iou_batch = iou_torch(
                    sliced_prediction[:, 1:2, :, :], sliced_mask.to(torch.int)
                )

                # The stem is computed outside the f-string because reusing the
                # same quote character inside an f-string is a syntax error
                # before Python 3.12. The name encodes the first and last slice
                # index of the chunk.
                img_stem = img_name.split(".")[0]
                file_path = save_dir / f"{img_stem}-{slices[0]}-{slices[-1]}.nrrd"

                entry = result_dict.setdefault(
                    img_name,
                    {
                        "prediction_path": [],
                        "sensitivity": [],
                        "specificity": [],
                        "iou": [],
                        "dice_score": [],
                        "dice_ce": [],
                    },
                )
                entry["prediction_path"].append(str(file_path))
                entry["sensitivity"].append(sensitivity)
                entry["specificity"].append(specificity)
                entry["iou"].append(iou_batch)
                entry["dice_score"].append(dice_score_img)
                entry["dice_ce"].append(dice_ce)

                if save:
                    # Post-processing is only needed for the file on disk, so
                    # it is skipped entirely for metric-only runs.
                    orig_size = data["original_size"][j]
                    orig_size_masks = net.postprocess_masks(
                        masks=sliced_prediction.to(torch.float32),
                        input_size=(1024, 1024),
                        original_size=orig_size,
                    ).squeeze(1)
                    nrrd_vol = sitk.GetImageFromArray(
                        orig_size_masks.to(torch.uint8).cpu().numpy()[:, 1, :, :]
                    )
                    # Pass a plain string so the call does not depend on the
                    # installed SimpleITK accepting pathlib objects.
                    sitk.WriteImage(nrrd_vol, fileName=str(file_path))
            pbar.update(1)

    return result_dict

In [9]:
# Threshold used for the final predictions, and output locations.
best_thresh = -1.0
model_identifier = "mevis_sam-epoch151-29334c5d6e2b41bf86e613599ffeaff3"
eval_results_base = Path("./eval_results")
predicted_msk_folder = Path("/data/sab_data/predicted_masks/v2")
eval_results = eval_results_base / model_identifier
eval_results.mkdir(parents=True, exist_ok=True)

save_dir = predicted_msk_folder / model_identifier
save_dir.mkdir(parents=True, exist_ok=True)
save = True

# Evaluate each split in the original order (test, validation, train) and
# write its metrics as soon as the split finishes.
saved_inference = {}
for split_name, split_dataset in (
    ("Test", dataset_test),
    ("Validation", dataset_validation),
    ("Train", dataset_train),
):
    saved_inference[split_name] = run_save_inference(
        split_dataset, model, best_thresh=best_thresh, save_dir=save_dir, save=save
    )
    results_file = (
        eval_results
        / f"{split_name}_inference_results_{merged_args.prompt_probability}_thresh{best_thresh}.json"
    )
    with open(results_file, "w") as f:
        json.dump(
            saved_inference[split_name],
            f,
            indent=4,
            sort_keys=False,
            separators=(",", ": "),
        )

# Keep the per-split names that the rest of the notebook uses.
test_inference = saved_inference["Test"]
validation_inference = saved_inference["Validation"]
train_inference = saved_inference["Train"]

Running inference on queued batches of size:   0%|          | 0/12 [00:00<?, ?Batch/s]

Running inference on queued batches of size:   0%|          | 0/12 [00:00<?, ?Batch/s]

Running inference on queued batches of size:   0%|          | 0/94 [00:00<?, ?Batch/s]