In [14]:
import json
import os
import copy

In [15]:
base_config_fpath = "/home/bryan/data/mbas_nnUNet_preprocessed/Dataset104_MBAS/nnUNetResEncUNetMPlans.json"

new_config_fpath = "/home/bryan/data/mbas_nnUNet_preprocessed/Dataset104_MBAS/nnUNetResEncUNetMPlans_2024_08_10.json"

In [16]:
with open(base_config_fpath, "r") as f:
    base_config = json.load(f)

In [17]:
def set_model_params(
    base_config,
    data_identifier = None,
    batch_size = 2,
    patch_size = (20, 256, 256),
    features_per_stage = (32, 64, 128, 256, 320, 320),
    kernel_sizes=[
        [1, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
    ],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
        [1, 2, 2],
    ],
    n_blocks_per_stage=[1, 3, 4, 6, 6, 6],
    n_conv_per_stage_decoder = [1, 1, 1, 1, 1],

    boundary_loss_alpha_stepsize = 5,
    boundary_loss_alpha_warmup_epochs = 500,
    boundary_loss_alpha_max = 0.25,
    alpha_stepwise_warmup_scaled = True,

    oversample_foreground_percent = 1.0,
    probabilistic_oversampling = False,
    sample_class_probabilities = None,
    
    batch_dice = True,
):
    args_dict = locals()
    config_copy = copy.deepcopy(base_config)
    
    config_params = (
        "batch_size", "patch_size",
        "boundary_loss_alpha_stepsize", "boundary_loss_alpha_warmup_epochs",
        "boundary_loss_alpha_max", "alpha_stepwise_warmup_scaled",
        "oversample_foreground_percent", "probabilistic_oversampling", "sample_class_probabilities",
        "cascaded_mask_dilation",
        "batch_dice"
    )
    for config_param in config_params:
        if config_param in args_dict:
            config_copy[config_param] = args_dict[config_param]
        else:
            print(f"config_param: {config_param} not found in args")
    
    if data_identifier is not None:
        if data_identifier == "nnUNetPlans_3d_fullres":
            config_copy["spacing"] = [2.5, 0.625, 0.625]
        elif data_identifier == "3d_lowres_1.0":
            config_copy["spacing"] = [2.5, 1.0, 1.0]
        elif data_identifier == "3d_lowres_1.25":
            config_copy["spacing"] = [2.5, 1.25, 1.25]
        elif data_identifier == "3d_lowres_1.5":
            config_copy["spacing"] = [2.5, 1.5, 1.5]
        else:
            print(f"WARNING unknown data_identifier: {data_identifier}")
        config_copy["data_identifier"] = data_identifier
    
    
    arch_params = (
        "n_stages", "features_per_stage",
        "kernel_sizes",
        "n_blocks_per_stage", "n_conv_per_stage_decoder",
    )
    
    
    arch = config_copy["architecture"]["arch_kwargs"]
    
    n_stages = len(features_per_stage)
    assert len(kernel_sizes) == n_stages
    assert len(strides) == n_stages
    assert len(n_blocks_per_stage)== n_stages
    assert len(n_conv_per_stage_decoder) == n_stages - 1

    args_dict["n_stages"] = n_stages
    for arch_param in arch_params:
        arch[arch_param] = args_dict[arch_param]
    return config_copy

In [18]:
base_model_config = base_config["configurations"]["3d_lowres"]

# 3d_lowres_1.0 avg size 400
# 3d_lowres_1.0 avg size 320
# 3d_lowres_1.5 avg size 270

new_config = copy.deepcopy(base_config)
new_config["plans_name"] = os.path.basename(os.path.splitext(new_config_fpath)[0])
# new_config["configurations"] = {}
new_config["configurations"]["lowres1.0_M_16_256"] = set_model_params(
    base_config["configurations"]["3d_lowres"],
    patch_size = (16, 256, 256),
    data_identifier = "3d_lowres_1.0",
) # 5800mb
new_config["configurations"]["lowres1.25_M_16_256"] = set_model_params(
    base_config["configurations"]["3d_lowres"],
    patch_size = (16, 256, 256),
    data_identifier = "3d_lowres_1.25",
)
new_config["configurations"]["lowres1.5_M_16_256"] = set_model_params(
    base_config["configurations"]["3d_lowres"],
    patch_size = (16, 256, 256),
    data_identifier = "3d_lowres_1.5",
)

new_config["configurations"]["lowres1.0_M_16_256_nblocks3"] = set_model_params(
    base_config["configurations"]["3d_lowres"],
    patch_size = (16, 256, 256),
    data_identifier = "3d_lowres_1.0",
    n_blocks_per_stage=[3, 3, 4, 6, 6, 6],
) # 6800
new_config["configurations"]["lowres1.25_M_16_256_nblocks3"] = set_model_params(
    base_config["configurations"]["3d_lowres"],
    patch_size = (16, 256, 256),
    data_identifier = "3d_lowres_1.25",
    n_blocks_per_stage=[3, 3, 4, 6, 6, 6],
)
new_config["configurations"]["lowres1.5_M_16_256_nblocks3"] = set_model_params(
    base_config["configurations"]["3d_lowres"],
    patch_size = (16, 256, 256),
    data_identifier = "3d_lowres_1.5",
    n_blocks_per_stage=[3, 3, 4, 6, 6, 6],
)

new_config["configurations"]["lowres1.0_M_40_256_nblocks3"] = set_model_params(
    base_config["configurations"]["3d_lowres"],
    patch_size = (40, 256, 256),
    data_identifier = "3d_lowres_1.0",
    n_blocks_per_stage=[3, 3, 4, 6, 6, 6],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
    ],
) # 6800
new_config["configurations"]["lowres1.25_M_40_256_nblocks3"] = set_model_params(
    base_config["configurations"]["3d_lowres"],
    patch_size = (40, 256, 256),
    data_identifier = "3d_lowres_1.25",
    n_blocks_per_stage=[3, 3, 4, 6, 6, 6],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
    ],
)
new_config["configurations"]["lowres1.5_M_40_256_nblocks3"] = set_model_params(
    base_config["configurations"]["3d_lowres"],
    patch_size = (40, 256, 256),
    data_identifier = "3d_lowres_1.5",
    n_blocks_per_stage=[3, 3, 4, 6, 6, 6],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
    ],
)

# 7900 full res
new_config["configurations"]["fullres_M_16_256"] = set_model_params(
    base_config["configurations"]["3d_fullres"],
    patch_size = (16, 256, 256),
    features_per_stage = (32, 64, 128, 256, 320, 320, 320),
    kernel_sizes=[
        [1, 3, 3],
        [1, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
    ],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
        [1, 2, 2],
    ],
    n_blocks_per_stage=[1, 3, 4, 6, 6, 6, 6],
    n_conv_per_stage_decoder = [1] * 6,
)
new_config["configurations"]["fullres_M_16_256_nblocks3"] = set_model_params(
    base_config["configurations"]["3d_fullres"],
    patch_size = (16, 256, 256),
    features_per_stage = (32, 64, 128, 256, 320, 320, 320),
    kernel_sizes=[
        [1, 3, 3],
        [1, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
    ],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
        [1, 2, 2],
    ],
    n_blocks_per_stage=[3, 3, 4, 6, 6, 6, 6],
    n_conv_per_stage_decoder = [1] * 6,
) # 8000mb
new_config["configurations"]["fullres_M_32_256"] = set_model_params(
    base_config["configurations"]["3d_fullres"],
    patch_size = (32, 256, 256),
    features_per_stage = (32, 64, 128, 256, 320, 320, 320),
    kernel_sizes=[
        [1, 3, 3],
        [1, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
    ],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
    ],
    n_blocks_per_stage=[1, 3, 4, 6, 6, 6, 6],
    n_conv_per_stage_decoder = [1] * 6,
)
new_config["configurations"]["fullres_M_32_256_nblocks3"] = set_model_params(
    base_config["configurations"]["3d_fullres"],
    patch_size = (32, 256, 256),
    features_per_stage = (32, 64, 128, 256, 320, 320, 320),
    kernel_sizes=[
        [1, 3, 3],
        [1, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
    ],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
    ],
    n_blocks_per_stage=[3, 3, 4, 6, 6, 6, 6],
    n_conv_per_stage_decoder = [1] * 6,
)

config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args
config_param: cascaded_mask_dilation not found in args


In [19]:
with open(new_config_fpath, "w") as f:
    json.dump(new_config, f, indent=2)