In [30]:
import numpy as np
import pandas as pd
import ml_collections
from tqdm import tqdm
import tensorflow as tf

import os
import shutil
import importlib
from typing import Callable, Optional

from models import vanilla
from utils import plotting
from metrics import compute_segmentation_metrics
from utils.datasets.foscal.patient import FOSCALPatient
from utils.config import get_path_of_directory_with_id, load_yaml_config
from utils.preprocessing.numpy import binarize_array
from utils.preprocessing.tensorflow import resize_data, resize_mask

# Helpers

In [13]:
def import_method_from_config(
    config: ml_collections.ConfigDict, attr: str
) -> Optional[Callable]:
    """Resolve the dotted import path stored in ``config[attr]``.

    Args:
        config: Config section that may contain ``attr``.
        attr: Key whose value is a dotted path understood by ``import_method``.

    Returns:
        The imported object, or ``None`` when ``attr`` is not present in ``config``.
    """
    if attr not in config:
        return None
    return import_method(config[attr])

def import_method(method_path: str) -> Callable:
    """Import an object (function, class, ...) given its dotted path.

    The first component may be one of the short aliases ``np``, ``tf`` or
    ``tfa``, which are expanded to ``numpy``, ``tensorflow`` and
    ``tensorflow_addons``. The longest importable module prefix is imported and
    the remaining components are resolved with ``getattr``. This supports both
    ``"module.function"`` and nested attributes such as
    ``"collections.OrderedDict.fromkeys"``.

    Args:
        method_path: Dotted path, e.g. ``"tf.keras.layers.BatchNormalization"``.

    Returns:
        The object the path refers to.

    Raises:
        ValueError: If ``method_path`` has no module part (no ``"."``).
        ModuleNotFoundError: If no prefix of the path is an importable module.
        AttributeError: If the module exists but lacks the requested attribute.
    """
    aliases = {"np": "numpy", "tf": "tensorflow", "tfa": "tensorflow_addons"}
    method_shards = method_path.split(".")
    if len(method_shards) < 2:
        raise ValueError(
            f"Expected a dotted path like 'module.attribute', got {method_path!r}."
        )
    method_shards[0] = aliases.get(method_shards[0], method_shards[0])

    first_error = None
    # Try "a.b.c" as module "a.b" + attribute "c" first (the common case), then
    # fall back to shorter module prefixes for nested attributes.
    for split in range(len(method_shards) - 1, 0, -1):
        module_path = ".".join(method_shards[:split])
        try:
            obj = importlib.import_module(module_path)
        except ModuleNotFoundError as err:
            missing = err.name or ""
            # Only treat "this prefix is not a module" as a reason to fall back;
            # an error from inside an existing module (e.g. one of its own
            # dependencies is missing) is re-raised instead of being masked.
            if not (module_path == missing or module_path.startswith(missing + ".")):
                raise
            if first_error is None:
                first_error = err
            continue
        for attr_name in method_shards[split:]:
            obj = getattr(obj, attr_name)
        return obj
    # The loop ran at least once, and every non-returning iteration recorded an error.
    raise first_error

# Evaluation

In [50]:
def evaluate_dual_unet_on_foscal_dataset(model, config, subdir):
    """Evaluate a dual-input (ADC + DWI) UNet on the FOSCAL validation patients.

    Each validation patient is predicted once. Both probability maps are
    binarized at 0.5, resized back to the patient's original in-plane shape and
    scored against the annotations of every radiologist ("Daniel", "Andres").
    For each radiologist the function writes, under
    ``<evaluation_dir>[/<subdir>]/valid/<radiologist>``:

    * ``<patient_id>/niftis/``: copies of that radiologist's ADC/DWI mask files,
      when the patient exposes ``<modality>_<radiologist>_mask_path``;
    * ``<patient_id>/{adc,dwi}_data_and_ots_overlapped.png`` and
      ``<patient_id>/{adc,dwi}_animation.gif``;
    * ``{adc,dwi}_patient_metrics.csv``: one row per patient plus ``mean`` and
      ``std`` rows computed over the patient rows only. The std uses pandas'
      default ``ddof=1``, so it is NaN when only one patient was evaluated.

    Args:
        model: Model whose ``predict`` takes an ``(adc, dwi)`` tuple of inputs
            and returns ``(adc_probabilities, dwi_probabilities)``.
        config: Experiment ``ConfigDict``; reads ``evaluation_dir``,
            ``model.input_shape``, ``dataloader.modalities`` and
            ``dataloader.valid_patients_path``.
        subdir: Optional sub-directory of ``config.evaluation_dir``; ``None``
            writes directly into it.
    """
    if subdir is None:
        save_dir = config.evaluation_dir
    else:
        save_dir = os.path.join(config.evaluation_dir, subdir)

    input_shape = tuple(int(x) for x in config.model.input_shape.split(","))
    modalities = config.dataloader.modalities.split(",")
    radiologists = ["Daniel", "Andres"]
    dual_modalities = ["ADC", "DWI"]
    metric_names = ["sens", "spec", "ppv", "npv", "dsc", "avd", "hd"]

    valid_dirs = {r: os.path.join(save_dir, "valid", r) for r in radiologists}
    for valid_dir in valid_dirs.values():
        os.makedirs(valid_dir, exist_ok=True)

    # results[radiologist][modality] -> {metric: [...], "patient": [...]}
    results = {
        radiologist: {
            modality: {**{m: [] for m in metric_names}, "patient": []}
            for modality in dual_modalities
        }
        for radiologist in radiologists
    }

    # atleast_1d: a file listing a single patient would otherwise load as a
    # 0-d array, which cannot be iterated.
    patient_paths = np.atleast_1d(
        np.loadtxt(config.dataloader.valid_patients_path, dtype=str)
    )
    for patient_path in patient_paths:
        patient = FOSCALPatient(str(patient_path))
        original_shape = (patient.original_shape[0], patient.original_shape[1])

        # Move each volume's last axis first, add a channel axis and stack the
        # modalities as channels. The last axis is presumably the slice axis
        # (TODO confirm).
        data_dict = patient.get_data(modalities, normalization="min_max")
        stacked_data = np.concatenate(
            [np.expand_dims(v.transpose(2, 0, 1), -1) for v in data_dict.values()],
            axis=-1,
        )
        resized_data = resize_data(stacked_data, input_shape[:2])

        # The first channel feeds the ADC branch; the remaining channel(s) feed
        # the DWI branch.
        adc_probabilities, dwi_probabilities = model.predict(
            (resized_data[..., 0:1], resized_data[..., 1:])
        )

        # Binarize, restore the original axis order and resize to the original shape.
        pred_masks = {}
        for modality, probabilities in zip(
            dual_modalities, (adc_probabilities, dwi_probabilities)
        ):
            pred_mask = binarize_array(probabilities, threshold=0.5)
            pred_mask = pred_mask[..., 0].transpose(1, 2, 0)
            pred_masks[modality] = resize_mask(pred_mask, original_shape).numpy()

        # The prediction does not depend on the radiologist, so it is computed
        # once above and scored against each radiologist's annotation here.
        for radiologist in radiologists:
            patient_dir = os.path.join(valid_dirs[radiologist], patient.patient_id)
            niftis_dir = os.path.join(patient_dir, "niftis")
            os.makedirs(niftis_dir, exist_ok=True)

            masks = patient.get_mask(modalities=modalities, radiologist=radiologist)
            for modality in dual_modalities:
                prefix = modality.lower()

                # Copy the mask of the radiologist being evaluated.
                # NOTE(review): assumes FOSCALPatient exposes
                # `<modality>_<radiologist>_mask_path` (e.g. `adc_andres_mask_path`);
                # missing attributes are skipped rather than copying another
                # radiologist's file.
                mask_path = getattr(
                    patient, f"{prefix}_{radiologist.lower()}_mask_path", None
                )
                if mask_path is not None:
                    shutil.copy(mask_path, niftis_dir)

                mask = masks[modality]
                pred_mask = pred_masks[modality]
                vol_metrics = compute_segmentation_metrics(mask, pred_mask)
                modality_results = results[radiologist][modality]
                modality_results["patient"].append(patient.patient_id)
                for m in metric_names:
                    modality_results[m].append(vol_metrics[m])

                title = plotting.create_title_with_metrics(**vol_metrics)
                plotting.plot_data_with_overlapping_ots(
                    data_dict[modality],
                    mask,
                    pred_mask,
                    save_path=os.path.join(
                        patient_dir, f"{prefix}_data_and_ots_overlapped.png"
                    ),
                    show_plot=False,
                    title=title,
                )
                plotting.save_animated_data_with_overlapping_ots(
                    data_dict[modality],
                    mask,
                    pred_mask,
                    save_path=os.path.join(patient_dir, f"{prefix}_animation.gif"),
                    titles=title,
                )

    for radiologist in radiologists:
        for modality in dual_modalities:
            modality_results = results[radiologist][modality]
            if not modality_results["patient"]:
                continue
            metrics_per_patient = pd.DataFrame(modality_results).set_index("patient")
            # Compute both statistics on the patient rows only. Appending the
            # "mean" row first would leak it into the std.
            mean_row = metrics_per_patient.mean()
            std_row = metrics_per_patient.std()
            metrics_per_patient.loc["mean"] = mean_row
            metrics_per_patient.loc["std"] = std_row
            metrics_per_patient.to_csv(
                os.path.join(
                    valid_dirs[radiologist], f"{modality.lower()}_patient_metrics.csv"
                )
            )

def evaluate_unet_on_foscal_dataset(model, config, subdir):
    """Evaluate a single-output (optionally deeply supervised) UNet on FOSCAL.

    Each validation patient listed in ``config.dataloader.valid_patients_path``
    is predicted once. The probability map (the last output when ``predict``
    returns a list/tuple) is binarized at 0.5, resized back to the patient's
    original in-plane shape and scored against each radiologist's annotation of
    the first configured modality. Patients lacking an attribute named after
    that modality (lowercased) are skipped. For each radiologist the function
    writes, under ``<evaluation_dir>[/<subdir>]/valid/<radiologist>``:

    * ``<patient_id>/niftis/``: a copy of ``patient.<modality>_path``;
    * ``<patient_id>/data_and_ots_overlapped.png`` and ``<patient_id>/animation.gif``;
    * ``patient_metrics.csv``: one row per patient plus ``mean`` and ``std``
      rows computed over the patient rows only. The std uses pandas' default
      ``ddof=1``, so it is NaN when only one patient was evaluated.

    Args:
        model: Model whose ``predict`` takes the stacked, resized input volume.
        config: Experiment ``ConfigDict``; reads ``evaluation_dir``,
            ``model.input_shape``, ``dataloader.modalities`` and
            ``dataloader.valid_patients_path``.
        subdir: Optional sub-directory of ``config.evaluation_dir``; ``None``
            writes directly into it.
    """
    if subdir is None:
        save_dir = config.evaluation_dir
    else:
        save_dir = os.path.join(config.evaluation_dir, subdir)

    input_shape = tuple(int(x) for x in config.model.input_shape.split(","))
    modalities = config.dataloader.modalities.split(",")
    radiologists = ["Daniel", "Andres"]
    metric_names = ["sens", "spec", "ppv", "npv", "dsc", "avd", "hd"]

    # The annotation of the first configured modality is the one being scored.
    target_modality = modalities[0]
    mask_attr = target_modality.lower()

    valid_dirs = {r: os.path.join(save_dir, "valid", r) for r in radiologists}
    for valid_dir in valid_dirs.values():
        os.makedirs(valid_dir, exist_ok=True)
    results = {
        radiologist: {**{m: [] for m in metric_names}, "patient": []}
        for radiologist in radiologists
    }

    # atleast_1d: a file listing a single patient would otherwise load as a
    # 0-d array, which cannot be iterated.
    patient_paths = np.atleast_1d(
        np.loadtxt(config.dataloader.valid_patients_path, dtype=str)
    )
    for patient_path in patient_paths:
        patient = FOSCALPatient(str(patient_path))
        original_shape = (patient.original_shape[0], patient.original_shape[1])

        # Move each volume's last axis first, add a channel axis and stack the
        # modalities as channels.
        data_dict = patient.get_data(modalities, normalization="min_max")
        stacked_data = np.concatenate(
            [np.expand_dims(v.transpose(2, 0, 1), -1) for v in data_dict.values()],
            axis=-1,
        )
        resized_data = resize_data(stacked_data, input_shape[:2])

        # If the model is deeply supervised (returns several outputs), keep the last one.
        probabilities = model.predict(resized_data)
        if isinstance(probabilities, (list, tuple)):
            probabilities = probabilities[-1]
        pred_mask = binarize_array(probabilities, threshold=0.5)
        pred_mask = pred_mask[..., 0].transpose(1, 2, 0)
        resized_pred_mask = resize_mask(pred_mask, original_shape).numpy()

        # Plots are drawn over the first loaded volume.
        plot_volume = data_dict[next(iter(data_dict))]

        # The prediction does not depend on the radiologist, so it is computed
        # once above and scored against each radiologist's annotation here.
        for radiologist in radiologists:
            patient_dir = os.path.join(valid_dirs[radiologist], patient.patient_id)
            niftis_dir = os.path.join(patient_dir, "niftis")
            os.makedirs(niftis_dir, exist_ok=True)

            if not hasattr(patient, mask_attr):
                continue

            # NOTE(review): this copies `patient.<modality>_path`, whereas the dual
            # evaluation copies the radiologist's mask file. Confirm which one
            # is intended.
            shutil.copy(getattr(patient, mask_attr + "_path"), niftis_dir)

            mask = patient.get_mask(modalities=modalities, radiologist=radiologist)[
                target_modality
            ]
            vol_metrics = compute_segmentation_metrics(mask, resized_pred_mask)
            results[radiologist]["patient"].append(patient.patient_id)
            for m in metric_names:
                results[radiologist][m].append(vol_metrics[m])

            title = plotting.create_title_with_metrics(**vol_metrics)
            plotting.plot_data_with_overlapping_ots(
                plot_volume,
                mask,
                resized_pred_mask,
                save_path=os.path.join(patient_dir, "data_and_ots_overlapped.png"),
                show_plot=False,
                title=title,
            )
            plotting.save_animated_data_with_overlapping_ots(
                plot_volume,
                mask,
                resized_pred_mask,
                save_path=os.path.join(patient_dir, "animation.gif"),
                titles=title,
            )

    for radiologist in radiologists:
        if not results[radiologist]["patient"]:
            continue
        metrics_per_patient = pd.DataFrame(results[radiologist]).set_index("patient")
        # Compute both statistics on the patient rows only. Appending the
        # "mean" row first would leak it into the std.
        mean_row = metrics_per_patient.mean()
        std_row = metrics_per_patient.std()
        metrics_per_patient.loc["mean"] = mean_row
        metrics_per_patient.loc["std"] = std_row
        metrics_per_patient.to_csv(
            os.path.join(valid_dirs[radiologist], "patient_metrics.csv")
        )

In [51]:
def evaluate_experiment_best_model(experiment_id):
    """Rebuild an experiment's model, load its best weights and evaluate it.

    The experiment directory is located from ``experiment_id``, and its
    ``config/config.yml`` is used to rebuild the architecture.

    * Ids above 9 are built as ``vanilla.DualUnet`` and evaluated with
      ``evaluate_dual_unet_on_foscal_dataset``.
    * Lower ids are built as deeply supervised ``vanilla.UNet`` models and
      evaluated with ``evaluate_unet_on_foscal_dataset``.

    Results go to the ``best`` sub-directory of the experiment's evaluation dir.
    """
    tf.keras.backend.clear_session()

    config_path = os.path.join(
        get_path_of_directory_with_id(experiment_id), "config", "config.yml"
    )
    config = ml_collections.ConfigDict(load_yaml_config(config_path))
    model_section = config.model

    # Convert the comma-separated strings stored in the YAML into Python values.
    model_config = model_section.to_dict()
    model_config["input_shape"] = tuple(
        int(s) for s in model_section.input_shape.split(",")
    )
    for key in ("filters_per_level", "blocks_depth"):
        model_config[key] = [int(v) for v in model_section[key].split(",")]
    # Optional layer factories given as dotted import paths (None when absent).
    for key in ("norm_layer", "upsample_layer", "attention_layer", "pooling_layer"):
        model_config[key] = import_method_from_config(model_section, key)

    is_dual = experiment_id > 9
    if is_dual:
        model = vanilla.DualUnet(**model_config)
    else:
        encoder = vanilla.UNetEncoder(**model_config)
        unet = vanilla.UNet(
            encoder, vanilla.get_skip_names_from_encoder(encoder), **model_config
        )
        model = vanilla.add_deep_supervision_to_unet(
            unet, vanilla.get_output_names_for_deep_supervision(unet)
        )
    model.load_weights(config.best_weights_path).expect_partial()

    # Evaluate on the validation patients against both radiologists.
    if is_dual:
        evaluate_dual_unet_on_foscal_dataset(model, config, "best")
    else:
        evaluate_unet_on_foscal_dataset(model, config, "best")

Evaluate the baseline UNets (experiments 0–9)

In [52]:
# Experiments 0-9 take the single-UNet branch of evaluate_experiment_best_model.
baseline_experiment_ids = range(10)
for experiment_id in tqdm(baseline_experiment_ids):
    evaluate_experiment_best_model(experiment_id)

100%|██████████| 10/10 [16:43<00:00, 100.40s/it]


Evaluate the Dual UNets (experiments 10–14)

In [53]:
# Experiments 10-14 (ids > 9) take the DualUnet branch of evaluate_experiment_best_model.
dual_experiment_ids = range(10, 15)
for experiment_id in tqdm(dual_experiment_ids):
    evaluate_experiment_best_model(experiment_id)

100%|██████████| 5/5 [16:02<00:00, 192.54s/it]


# Predict for figures

In [73]:
def dual_predict_on_foscal_patients(
    model,
    config,
    patient_paths,
    experiment_id,
    output_dir=os.path.join("figs", "results"),
):
    """Predict ADC/DWI lesion masks with a dual UNet and save arrays for figures.

    For every patient, this writes the following under
    ``<output_dir>/<experiment_id>/<patient folder name>/``:

    * the normalized input volumes (``adc.npy``, ``dwi.npy``);
    * both radiologists' annotations (``{adc,dwi}_{daniel,andres}_mask.npy``);
    * the binarized (threshold 0.5) prediction resized to the original
      in-plane shape (``{adc,dwi}_resized_pred_mask.npy``).

    Args:
        model: Model whose ``predict`` takes an ``(adc, dwi)`` tuple of inputs
            and returns ``(adc_probabilities, dwi_probabilities)``.
        config: Experiment ``ConfigDict``; reads ``model.input_shape`` and
            ``dataloader.modalities``.
        patient_paths: Iterable of patient directories.
        experiment_id: Name of the per-experiment output folder.
        output_dir: Root folder for the saved arrays (default ``figs/results``).
    """
    input_shape = tuple(int(x) for x in config.model.input_shape.split(","))
    modalities = config.dataloader.modalities.split(",")

    for patient_path in patient_paths:
        # normpath strips a trailing separator, so basename yields the patient
        # folder name instead of "".
        patient_name = os.path.basename(os.path.normpath(str(patient_path)))
        patient_dir = os.path.join(output_dir, str(experiment_id), patient_name)
        os.makedirs(patient_dir, exist_ok=True)

        patient = FOSCALPatient(str(patient_path))
        original_shape = (patient.original_shape[0], patient.original_shape[1])

        # Move each volume's last axis first, add a channel axis and stack the
        # modalities as channels.
        data_dict = patient.get_data(modalities, normalization="min_max")
        stacked_data = np.concatenate(
            [np.expand_dims(v.transpose(2, 0, 1), -1) for v in data_dict.values()],
            axis=-1,
        )
        resized_data = resize_data(stacked_data, input_shape[:2])

        # The first channel feeds the ADC branch; the remaining channel(s) feed
        # the DWI branch.
        adc_probabilities, dwi_probabilities = model.predict(
            (resized_data[..., 0:1], resized_data[..., 1:])
        )

        # Annotations of each radiologist.
        daniel_masks = patient.get_mask(modalities=modalities, radiologist="Daniel")
        andres_masks = patient.get_mask(modalities=modalities, radiologist="Andres")

        # Channel 0 holds ADC and channel 1 holds DWI, matching the split above.
        for channel, (modality, probabilities) in enumerate(
            (("ADC", adc_probabilities), ("DWI", dwi_probabilities))
        ):
            prefix = modality.lower()
            pred_mask = binarize_array(probabilities, threshold=0.5)
            pred_mask = pred_mask[..., 0].transpose(1, 2, 0)
            resized_pred_mask = resize_mask(pred_mask, original_shape).numpy()

            np.save(
                os.path.join(patient_dir, f"{prefix}.npy"),
                stacked_data[..., channel].transpose(1, 2, 0),
            )
            np.save(
                os.path.join(patient_dir, f"{prefix}_daniel_mask.npy"),
                daniel_masks[modality],
            )
            np.save(
                os.path.join(patient_dir, f"{prefix}_andres_mask.npy"),
                andres_masks[modality],
            )
            np.save(
                os.path.join(patient_dir, f"{prefix}_resized_pred_mask.npy"),
                resized_pred_mask,
            )

def predict_on_foscal_patients(
    model,
    config,
    patient_paths,
    experiment_id,
    output_dir=os.path.join("figs", "results"),
):
    """Predict lesion masks with a single-modality UNet and save arrays for figures.

    For every patient, this writes the following under
    ``<output_dir>/<experiment_id>/<patient folder name>/``, where ``<m>`` is
    the lowercased modality name:

    * the normalized input volume (``<m>.npy``);
    * both radiologists' annotations (``<m>_{daniel,andres}_mask.npy``);
    * the binarized (threshold 0.5) prediction resized to the original
      in-plane shape (``<m>_resized_pred_mask.npy``).

    Args:
        model: Model whose ``predict`` returns probabilities, or a list/tuple of
            them for deeply supervised models (the last one is used).
        config: Experiment ``ConfigDict``; reads ``model.input_shape`` and
            ``dataloader.modalities``, which must name exactly one modality.
        patient_paths: Iterable of patient directories.
        experiment_id: Name of the per-experiment output folder.
        output_dir: Root folder for the saved arrays (default ``figs/results``).

    Raises:
        AssertionError: If more than one modality is configured.
    """
    input_shape = tuple(int(x) for x in config.model.input_shape.split(","))
    modalities = config.dataloader.modalities.split(",")
    assert len(modalities) == 1, (
        f"predict_on_foscal_patients expects a single modality, got {modalities}."
    )
    modality = modalities[0]
    prefix = modality.lower()

    for patient_path in patient_paths:
        # normpath strips a trailing separator, so basename yields the patient
        # folder name instead of "".
        patient_name = os.path.basename(os.path.normpath(str(patient_path)))
        patient_dir = os.path.join(output_dir, str(experiment_id), patient_name)
        os.makedirs(patient_dir, exist_ok=True)

        patient = FOSCALPatient(str(patient_path))
        original_shape = (patient.original_shape[0], patient.original_shape[1])

        # Move each volume's last axis first and add a channel axis.
        data_dict = patient.get_data(modalities, normalization="min_max")
        stacked_data = np.concatenate(
            [np.expand_dims(v.transpose(2, 0, 1), -1) for v in data_dict.values()],
            axis=-1,
        )
        resized_data = resize_data(stacked_data, input_shape[:2])

        # If the model is deeply supervised (returns several outputs), keep the last one.
        probabilities = model.predict(resized_data)
        if isinstance(probabilities, (list, tuple)):
            probabilities = probabilities[-1]
        pred_mask = binarize_array(probabilities, threshold=0.5)
        pred_mask = pred_mask[..., 0].transpose(1, 2, 0)
        resized_pred_mask = resize_mask(pred_mask, original_shape).numpy()

        # Annotations of each radiologist.
        daniel_mask = patient.get_mask(modalities=modalities, radiologist="Daniel")[modality]
        andres_mask = patient.get_mask(modalities=modalities, radiologist="Andres")[modality]

        np.save(
            os.path.join(patient_dir, f"{prefix}.npy"),
            stacked_data[..., 0].transpose(1, 2, 0),
        )
        np.save(os.path.join(patient_dir, f"{prefix}_daniel_mask.npy"), daniel_mask)
        np.save(os.path.join(patient_dir, f"{prefix}_andres_mask.npy"), andres_mask)
        np.save(
            os.path.join(patient_dir, f"{prefix}_resized_pred_mask.npy"),
            resized_pred_mask,
        )

In [74]:
# Patients 021 and 034, experiment 13 (fold 4).
def dual_predict_for_experiment(experiment_id, patient_paths):
    """Rebuild an experiment's best model and export predictions for figures.

    Despite the name, both architectures are handled:

    * ids above 9 are built as ``vanilla.DualUnet`` and exported with
      ``dual_predict_on_foscal_patients``;
    * lower ids are built as deeply supervised ``vanilla.UNet`` models and
      exported with ``predict_on_foscal_patients``.

    Args:
        experiment_id: Id used to locate the experiment directory.
        patient_paths: Iterable of patient directories to predict on.
    """
    tf.keras.backend.clear_session()

    config_path = os.path.join(
        get_path_of_directory_with_id(experiment_id), "config", "config.yml"
    )
    config = ml_collections.ConfigDict(load_yaml_config(config_path))
    model_section = config.model

    # Convert the comma-separated strings stored in the YAML into Python values.
    model_config = model_section.to_dict()
    model_config["input_shape"] = tuple(
        int(s) for s in model_section.input_shape.split(",")
    )
    for key in ("filters_per_level", "blocks_depth"):
        model_config[key] = [int(v) for v in model_section[key].split(",")]
    # Optional layer factories given as dotted import paths (None when absent).
    for key in ("norm_layer", "upsample_layer", "attention_layer", "pooling_layer"):
        model_config[key] = import_method_from_config(model_section, key)

    is_dual = experiment_id > 9
    if is_dual:
        model = vanilla.DualUnet(**model_config)
    else:
        encoder = vanilla.UNetEncoder(**model_config)
        unet = vanilla.UNet(
            encoder, vanilla.get_skip_names_from_encoder(encoder), **model_config
        )
        model = vanilla.add_deep_supervision_to_unet(
            unet, vanilla.get_output_names_for_deep_supervision(unet)
        )
    model.load_weights(config.best_weights_path).expect_partial()

    # Save inputs, annotations and predictions for the requested patients.
    if is_dual:
        dual_predict_on_foscal_patients(model, config, patient_paths, experiment_id)
    else:
        predict_on_foscal_patients(model, config, patient_paths, experiment_id)

In [75]:
# Root of the FOSCAL dataset on the evaluation machine.
FOSCAL_DATASET_DIR = "/data/Datasets/stroke/ISBI_FOSCAL"
patient_paths = [
    os.path.join(FOSCAL_DATASET_DIR, patient_id)
    for patient_id in ("ACV-034", "ACV-021", "ACV-031")
]

# Experiments whose predictions are exported for the figures.
for experiment_id in tqdm([6, 7, 13]):
    dual_predict_for_experiment(experiment_id, patient_paths)