In [None]:
import pathlib
import sys

import tensorflow as tf
from tqdm import tqdm

# Assumes the notebook is run from a directory one level below the project
# root, so that the `src` package is importable from the parent directory.
project_root_path = pathlib.Path.cwd().parent
# Only append once: re-running this cell must not grow `sys.path` with duplicates.
if project_root_path.as_posix() not in sys.path:
    sys.path.append(project_root_path.as_posix())

from src.data.preprocess.lib.utils import get_patient_split
from src.models.lib.builder import build_unet_pp
from src.models.lib.config import UNetPPConfig
from src.models.lib.loss import dice_coef, log_cosh_dice_loss
from src.models.lib.utils import loss_dict_gen

# NOTE(review): `create_dataset` is called in the "Pruned Model Creation" cell
# but is not imported anywhere in this notebook — TODO import it from the
# module that defines it (not visible here).

## Prepare Model

### Main Model Import

In [None]:
# Select model: collect every path under `model_root_path` whose name contains "model".
model_root_path = project_root_path / "models" / "path"  # Change this
# Sorted so that `model_paths[0]` (used in the next cell) is the same on every run;
# `rglob` yields entries in filesystem order, which is not guaranteed to be stable.
model_paths = sorted(model_root_path.rglob("*model*"))
model_paths

In [None]:
# Import main model: load the first candidate found by the previous cell.
if not model_paths:
    raise FileNotFoundError(
        f"No '*model*' entries found under {model_root_path}; update `model_root_path`."
    )
selected_model_path = model_paths[0].as_posix()
loss_func = log_cosh_dice_loss

# The custom loss and metric are registered by name so Keras can deserialize
# a model that was saved with them.
main_model = tf.keras.models.load_model(
    selected_model_path,
    custom_objects={
        "log_cosh_dice_loss": loss_func,
        "dice_coef": dice_coef,
    },
)

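### Pruned Model Creation

Build shallower UNet++ variants and initialize each one by copying weights from the identically named layers of the main model.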

In [None]:
# Architecture of the main model: `model_depth` levels with these per-level
# filter counts and downsample iterations.
model_depth = 5
filter_list = [16, 32, 64, 128, 256]
downsample_iteration = [4, 3, 2, 2, 1]
assert len(filter_list) >= model_depth and len(downsample_iteration) >= model_depth

# Entry whose config spans all `model_depth` levels; it uses `main_model`
# directly instead of a rebuilt copy.
full_depth_key = f"d{model_depth - 1}"

# Maps "d<depth>" -> {"model", "dataset", "config", "loss_dict"}.
pruned_model = {}

for depth in range(1, model_depth):
    key = f"d{depth}"
    pruned_model[key] = {}

    # Entry "d<depth>" is a UNet++ with depth + 1 levels, using the first
    # depth + 1 entries of the architecture lists above.
    model_config = UNetPPConfig(
        model_name=f"model_d{depth}",
        upsample_mode="upsample",
        depth=depth + 1,
        input_dim=[512, 512, 1],
        batch_norm=True,
        deep_supervision=True,
        model_mode="mobile",
        n_class={"bin": 1},
        filter_list=filter_list[: depth + 1],
        downsample_iteration=downsample_iteration[: depth + 1],
    )

    # Built even for the full-depth entry, because its output layer names are
    # needed for the loss dictionary below.
    model, output_layer_name = build_unet_pp(model_config, custom=True)
    if key == full_depth_key:
        pruned_model[key]["model"] = main_model
    else:
        print(f"-- Creating pruned model {key}")
        # Copy weights by layer name: every layer of the shallower model must
        # exist in `main_model` with the same name and weight shapes
        # (`get_layer` / `set_weights` raise otherwise).
        for layer in tqdm(model.layers):
            layer.set_weights(main_model.get_layer(layer.name).get_weights())
        pruned_model[key]["model"] = model

    loss_dict = loss_dict_gen(model_config, output_layer_name, [loss_func])

    # NOTE(review): `create_dataset` is not imported in this notebook, and the
    # meaning of the positional arguments 2 and 2048 is not visible here —
    # confirm against its definition.
    pruned_model[key]["dataset"] = create_dataset(
        project_root_path, model_config, 2, 2048
    )

    pruned_model[key]["config"] = model_config
    pruned_model[key]["loss_dict"] = loss_dict

### Pruned Model Compilation

In [None]:
for depth in range(1, model_depth):
    entry = pruned_model[f"d{depth}"]

    # A fresh list per model: Keras metric objects hold running state, so they
    # must not be shared between models.
    metrics = [
        dice_coef,
        tf.keras.metrics.BinaryIoU(),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision(),
        tf.keras.metrics.TruePositives(),
        tf.keras.metrics.TrueNegatives(),
        tf.keras.metrics.FalseNegatives(),
        tf.keras.metrics.FalsePositives(),
    ]

    # `compile` configures the model in place and returns None, so its result
    # must not be stored back into `entry["model"]`.
    entry["model"].compile(
        optimizer=tf.keras.optimizers.legacy.Adam(),
        loss=entry["loss_dict"],
        metrics=metrics,
    )

### Evaluate Models

Evaluate each model variant on the test split.

In [None]:
# Evaluate every model on the test split and keep the metrics for comparison.
eval_results = {}
for depth in range(1, model_depth):
    key = f"d{depth}"
    print(f"Evaluate model {key} on test dataset")
    # The test data lives under the entry's "dataset", not the model.
    # Assumes `create_dataset` returns a mapping with a "test" split — TODO confirm.
    eval_results[key] = pruned_model[key]["model"].evaluate(
        pruned_model[key]["dataset"]["test"],
        return_dict=True,
    )
eval_results