In [None]:
import json
import pathlib
import sys

import numpy as np
import sklearn.metrics as skm
import tensorflow as tf
from keras.utils.layer_utils import count_params
from tqdm import tqdm

# NOTE(review): later cells also use `plt`, `pdc` and `create_dataset` without
# importing them - presumably matplotlib.pyplot, pydicom and a project helper;
# confirm and add those imports to this cell.

sys.path.append(pathlib.Path.cwd().parent.as_posix())

from src.data.preprocess.lib.utils import get_patient_split
from src.models.lib.builder import build_unet_pp
from src.models.lib.config import UNetPPConfig
from src.models.lib.data_loader import preprocess_img
from src.models.lib.loss import dice_coef, log_cosh_dice_loss
from src.models.lib.utils import loss_dict_gen
from src.system.pipeline.output import auto_cac, ground_truth_auto_cac

# The project root is taken to be the parent of the current working directory.
notebook_dir = pathlib.Path.cwd()
project_root_path = notebook_dir.parent

## Model Testing

### Main Model Import

In [None]:
# Select model
# Directory searched for saved models; edit the last path component.
model_root_path = project_root_path / "models" / "path"  # Change this
# rglob yields entries in filesystem order, which differs between machines.
# Sorting keeps index-based selection from this list reproducible.
model_paths = sorted(model_root_path.rglob("*model*"))
model_paths

In [None]:
# Import main model
# Fail with an explicit message instead of a bare IndexError when the search
# for saved models came back empty.
if not model_paths:
    raise FileNotFoundError(
        f"No entry matching '*model*' was found under {model_root_path}"
    )

selected_model_path = model_paths[0].as_posix()
loss_func = log_cosh_dice_loss

# custom_objects maps the names of the project-specific loss and metric to
# their implementations so Keras can resolve them while loading.
main_model = tf.keras.models.load_model(
    selected_model_path,
    custom_objects={
        "log_cosh_dice_loss": loss_func,
        "dice_coef": dice_coef,
    },
)

### Pruned Model Creation

In [None]:
# Settings for the deepest variant; shallower variants use a prefix of each list.
model_depth = 5
filter_list = [16, 32, 64, 128, 256]
downsample_iteration = [4, 3, 2, 2, 1]


# Maps "d<depth>" -> {"model", "dataset", "config", "loss_dict"}.
pruned_model = {}


for depth in range(1, model_depth):
    pruned_model[f"d{depth}"] = {}

    # Variant d<depth> is configured with depth + 1 levels and the matching
    # prefix of the filter / downsample lists.
    model_config = UNetPPConfig(
        model_name=f"model_d{depth}",
        upsample_mode="upsample",
        depth=depth + 1,
        input_dim=[512, 512, 1],
        batch_norm=True,
        deep_supervision=True,
        model_mode="mobile",
        n_class={"bin": 1},
        filter_list=filter_list[: depth + 1],
        downsample_iteration=downsample_iteration[: depth + 1],
    )

    # The model is built for every depth because output_layer_name is needed
    # for the loss dictionary even when the built model itself is not kept.
    model, output_layer_name = build_unet_pp(model_config, custom=True)
    if depth != model_depth - 1:
        print(f"-- Creating pruned model d{depth}")
        # Copy weights from the main model layer by layer, matched by name.
        for layer in tqdm(model.layers):
            main_model_layer = main_model.get_layer(layer.name)
            layer.set_weights(main_model_layer.get_weights())

        pruned_model[f"d{depth}"]["model"] = model
    else:
        # The deepest variant is the loaded main model itself. The original
        # post-loop assignment used d{depth-1}, which overwrote the d3 model
        # and left d4 without a "model" entry.
        pruned_model[f"d{depth}"]["model"] = main_model

    loss_dict = loss_dict_gen(model_config, output_layer_name, [loss_func])

    # NOTE(review): create_dataset is not imported in the visible notebook -
    # confirm which module provides it and import it in the first cell.
    pruned_model[f"d{depth}"]["dataset"] = create_dataset(
        project_root_path, model_config, 2, 2048
    )

    pruned_model[f"d{depth}"]["config"] = model_config
    pruned_model[f"d{depth}"]["loss_dict"] = loss_dict

### Pruned Model Compilation

In [None]:
for depth in range(1, model_depth):
    model_entry = pruned_model[f"d{depth}"]
    current_model = model_entry["model"]

    # Fresh metric objects per model so no metric state is shared between depths.
    metrics = [
        dice_coef,
        tf.keras.metrics.BinaryIoU(),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision(),
        tf.keras.metrics.TruePositives(),
        tf.keras.metrics.TrueNegatives(),
        tf.keras.metrics.FalseNegatives(),
        tf.keras.metrics.FalsePositives(),
    ]

    # compile() configures the model in place and returns None, so its result
    # must not be stored back under "model" (that replaced the model with None).
    current_model.compile(
        optimizer=tf.keras.optimizers.legacy.Adam(),
        loss=model_entry["loss_dict"],
        metrics=metrics,
    )

    # Record parameter counts for later comparison between depths.
    model_entry["trainable_weights"] = count_params(current_model.trainable_weights)
    model_entry["non_trainable_weights"] = count_params(
        current_model.non_trainable_weights
    )
    model_entry["weights"] = count_params(current_model.weights)

### Evaluate Model (Quantitative)

In [None]:
for depth in range(1, model_depth):
    print(f"Evaluate model d{depth} on test dataset")
    print(pruned_model[f"d{depth}"])
    # The test split lives under the "dataset" entry; the original indexed the
    # model object itself, which is not subscriptable.
    # NOTE(review): assumes create_dataset returns a mapping with a "test" key - confirm.
    pruned_model[f"d{depth}"]["model"].evaluate(
        pruned_model[f"d{depth}"]["dataset"]["test"]
    )

### Evaluate Model (Qualitative)

In [None]:
# Ground Truth
# Patient and per-patient image entry to visualise.
patient_idx = 0
patient_img_idx = 0
bin_json_path = list(project_root_path.rglob("bin*.json"))[0]

with bin_json_path.open(mode="r") as json_file:
    bin_dict_output = json.load(json_file)

# Patients are keyed by their zero-padded three-digit index.
patient_info = bin_dict_output[str(patient_idx).zfill(3)]
patient_img_info = patient_info[patient_img_idx]


# Create lesion mask
patient_img_lesion = np.zeros((512, 512))
lesion_pos = patient_img_info["pos"]
# An empty position list would produce the empty index `()`, which selects the
# whole array and would mark every pixel as lesion, so only index when there
# is at least one position.
if len(lesion_pos) > 0:
    patient_img_lesion[tuple(zip(*lesion_pos))] = 1
# NOTE(review): positions are presumably (x, y) pairs, hence the transpose to
# image orientation - confirm against the JSON producer.
patient_img_lesion = np.transpose(patient_img_lesion)

# Get patient img
patient_img_num = patient_img_info["idx"]

print(
    f"Patient {patient_idx} Image {patient_img_idx+1}/{len(patient_info)} ({patient_img_num})"
)

# Find the patient directory, then the DICOM file whose name ends in
# "00" followed by the two-digit image number.
patient_root_path = next(project_root_path.rglob(f"patient/{patient_idx}"))
patient_dcm_path = next(
    patient_root_path.rglob(f"*00{str(patient_img_num).lstrip('0').zfill(2)}.dcm")
)
print(patient_dcm_path)
patient_dcm = pdc.dcmread(patient_dcm_path)
patient_img_arr = patient_dcm.pixel_array
# Apply the modality LUT stored in the DICOM dataset to the raw pixel values.
patient_img_hu = pdc.pixel_data_handlers.util.apply_modality_lut(
    patient_img_arr, patient_dcm
)
patient_img_hu_pre = preprocess_img(patient_img_hu)


# Plot
fig, ax = plt.subplots(1, 3, figsize=(20, 60))

ax[0].set_title("Image")
ax[0].axis("off")
ax[0].imshow(patient_img_hu, cmap="gray", interpolation="none")

ax[1].set_title("Binary Segment")
ax[1].imshow(np.ones([512, 512]), cmap="gray")
ax[1].axis("off")
ax[1].imshow(patient_img_lesion, cmap="gray")

# Preprocessed image with the lesion mask drawn semi-transparently on top.
ax[2].set_title("ROI Overlay")
ax[2].axis("off")
ax[2].imshow(patient_img_hu_pre, cmap="gray", interpolation="none")
ax[2].imshow(patient_img_lesion, cmap="gray", alpha=0.5)

plt.show()

In [None]:
# Insert a new leading axis (0) and a new axis at position 3 in one call.
img_model_input = np.expand_dims(patient_img_hu_pre, axis=(0, 3))

In [None]:
# Model Output
for depth in range(1, model_depth):
    depth_model = pruned_model[f"d{depth}"]["model"]
    model_out = depth_model.predict(img_model_input)
    # Threshold the squeezed prediction at 0.5 and turn it into a 0/1 mask.
    model_bin_seg = (np.squeeze(model_out) > 0.5) * 1

    fig, ax = plt.subplots(1, 3, figsize=(20, 40))
    image_ax, truth_ax, prediction_ax = ax[0], ax[1], ax[2]

    # All three panels share the same base: title, hidden axes, grayscale image.
    panel_titles = (
        (image_ax, "Image"),
        (truth_ax, "Ground Truth Segmentation"),
        (prediction_ax, f"Model Segmentation Depth {depth}"),
    )
    for panel, title in panel_titles:
        panel.set_title(title)
        panel.axis("off")
        panel.imshow(patient_img_hu, cmap="gray", interpolation="none")

    # Semi-transparent overlays: ground-truth mask and thresholded prediction.
    truth_ax.imshow(patient_img_lesion, cmap="gray", alpha=0.5)
    prediction_ax.imshow(model_bin_seg, cmap="gray", alpha=0.5)
    plt.show()

## System Testing

### Prepare Ground Truth System Output

In [None]:
# Read the first file under the project root that matches "bin*.json".
bin_json_path = list(project_root_path.rglob("bin*.json"))[0]
with bin_json_path.open(mode="r") as json_file:
    bin_dict_output = json.load(json_file)

# Test patients come from the 0.7 / 0.2 / 0.1 patient split.
patient_split = get_patient_split([0.7, 0.2, 0.1])
patient_test_data = set(patient_split["test"])
patient_with_segment = set(bin_dict_output.keys())

# Split patient with segmentation and no segmentation
patient_test_with_segment = patient_with_segment & patient_test_data
patient_test_no_segment = patient_test_data - patient_with_segment

In [None]:
# Add ground truth data for patient without segmentation
# Each such patient gets a zero total score and the "Absent" class.
ground_test_data_dict = {
    idx_no_seg: {"total_agatston": 0, "class": "Absent"}
    for idx_no_seg in patient_test_no_segment
}

In [None]:
for idx_seg in tqdm(patient_test_with_segment):
    # Patient ids are zero-padded strings. int() removes the padding safely,
    # whereas lstrip('0') turned "000" into "" and produced the pattern "patient/".
    patient_root_path = next(project_root_path.rglob(f"patient/{int(idx_seg)}"))

    segment_entries = bin_dict_output[idx_seg]

    # One DICOM path per annotated image: the file name ends in "00" followed
    # by the two-digit image index. int() also handles an all-zero index, which
    # the original int(idx.lstrip('0')) rejected with a ValueError.
    img_path = [
        next(patient_root_path.rglob(f"*00{int(entry['idx']):02d}.dcm"))
        for entry in segment_entries
    ]
    loc_list = [entry["pos"] for entry in segment_entries]

    ground_test_data_dict[idx_seg] = ground_truth_auto_cac(
        img_path, loc_list, mem_opt=True
    )

### Get Output from Pruned Model

In [None]:
# Maps "d<depth>" -> {patient id -> auto_cac output}.
model_test_data_dict = {}

for idx_seg in tqdm(patient_test_data):
    # Patient ids are zero-padded strings. int() removes the padding safely,
    # whereas lstrip('0') turned "000" into "" and produced the pattern "patient/".
    patient_root_path = next(project_root_path.rglob(f"patient/{int(idx_seg)}"))
    img_path = list(patient_root_path.rglob("*.dcm"))
    for depth in range(1, model_depth):
        # setdefault creates the per-depth dictionary on first use.
        depth_results = model_test_data_dict.setdefault(f"d{depth}", {})
        depth_results[idx_seg] = auto_cac(
            img_path, pruned_model[f"d{depth}"]["model"], mem_opt=True
        )

### Evaluate Agatston Score

In [None]:
# Get all agatston score for ground truth and model
agatston_eval = dict()


def get_total_agatston(x):
    """Return the value stored under the ``"total_agatston"`` key of ``x``."""
    total = x["total_agatston"]
    return total


# The ground-truth and model dictionaries are filled in different orders, so
# taking .values() from each would pair scores of different patients. Fix one
# patient order and use it for every array.
eval_patient_ids = sorted(ground_test_data_dict)

agatston_eval["ground_truth"] = np.array(
    [get_total_agatston(ground_test_data_dict[pid]) for pid in eval_patient_ids]
)

for depth in range(1, model_depth):
    depth_results = model_test_data_dict[f"d{depth}"]
    # A patient missing from the model results raises KeyError here instead of
    # silently shifting the alignment.
    agatston_eval[f"d{depth}"] = np.array(
        [get_total_agatston(depth_results[pid]) for pid in eval_patient_ids]
    )

In [None]:
# Calculate
def rmse(a, b):
    """Return the root-mean-square error between ``a`` and ``b``.

    Parameters
    ----------
    a, b : array-like of numbers
        Values to compare element-wise. Plain sequences are accepted; both are
        converted to float arrays so that squaring integer inputs cannot
        overflow.

    Returns
    -------
    float
        ``sqrt(sum((a - b) ** 2) / len(a))``.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    n = len(a)
    return np.sqrt(np.sum(np.square(a - b)) / n)


# Report the RMSE of every depth against the ground-truth scores.
for depth in range(1, model_depth):
    depth_rmse = rmse(agatston_eval["ground_truth"], agatston_eval[f"d{depth}"])
    print(f"RMSE for model depth {depth} is {depth_rmse}")

### Evaluate System Classification

In [None]:
def get_class(x):
    """Return the value stored under the ``"class"`` key of ``x``."""
    label = x["class"]
    return label


# classification_eval was never created in the visible notebook, so the first
# assignment below raised NameError on a fresh kernel.
classification_eval = {}

# The ground-truth and model dictionaries are filled in different orders, so
# taking .values() from each would pair labels of different patients. Fix one
# patient order and use it for every sequence.
eval_patient_ids = sorted(ground_test_data_dict)

classification_eval["ground_truth"] = [
    get_class(ground_test_data_dict[pid]) for pid in eval_patient_ids
]

for depth in range(1, model_depth):
    depth_results = model_test_data_dict[f"d{depth}"]
    # A patient missing from the model results raises KeyError here instead of
    # silently shifting the alignment.
    classification_eval[f"d{depth}"] = np.array(
        [get_class(depth_results[pid]) for pid in eval_patient_ids]
    )

In [None]:
for depth in range(1, model_depth):
    y_true = classification_eval["ground_truth"]
    y_pred = classification_eval[f"d{depth}"]

    # NOTE(review): with average='micro' on single-label data, F1, precision
    # and recall all equal accuracy, so the four numbers below are identical.
    # Consider 'macro' or per-class scores if they are meant to differ.
    accuracy = skm.accuracy_score(y_true, y_pred)
    dice = skm.f1_score(y_true, y_pred, average="micro")
    precision = skm.precision_score(y_true, y_pred, average="micro")
    recall = skm.recall_score(y_true, y_pred, average="micro")

    print(f"Accuracy for model depth {depth} is {accuracy}")
    print(f"Dice for model depth {depth} is {dice}")
    print(f"Precision for model depth {depth} is {precision}")
    print(f"Recall for model depth {depth} is {recall}")

    # The ConfusionMatrixDisplay constructor expects a precomputed confusion
    # matrix and draws nothing by itself. from_predictions computes the matrix
    # from the labels and plots it.
    confusion_display = skm.ConfusionMatrixDisplay.from_predictions(y_true, y_pred)
    confusion_display.ax_.set_title(f"Confusion matrix for model depth {depth}")