In [6]:
import json
import os
import copy

In [7]:
# Directory holding both the base plans file and the plans file generated below.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
PLANS_DIR = "/home/bryan/data/mbas_nnUNet_preprocessed/Dataset101_MBAS"
base_config_fpath = os.path.join(PLANS_DIR, "MedNeXtV2Plans_base.json")
new_config_fpath = os.path.join(PLANS_DIR, "MedNeXtV2Plans_2024_08_09.json")

In [8]:
# JSON text is UTF-8 by specification; pin the encoding instead of relying on
# the locale-dependent default of open().
with open(base_config_fpath, "r", encoding="utf-8") as f:
    base_config = json.load(f)

In [9]:
def set_model_params(
    base_config,
    batch_size = 2,
    patch_size = (20, 256, 256),
    features_per_stage = (32, 64, 128, 256, 320, 320, 320),
    stem_kernel_size = (1,3,3),
    stem_channels=None,
    stem_dilation=1,
    kernel_sizes=[
        [1, 3, 3],
        [1, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
    ],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
        [1, 2, 2],
    ],
    dilation_per_stage = 1,
    n_blocks_per_stage=[1, 3, 4, 6, 6, 6, 6],
    exp_ratio_per_stage=[2, 3, 4, 4, 4, 4, 4],
    n_blocks_per_stage_decoder=[1, 1, 1, 1, 1, 1],
    exp_ratio_per_stage_decoder=[4, 4, 4, 4, 3, 2],
    norm_type = "group",
    enable_affine_transform = False,
    decoder_cat_skip=False,

    boundary_loss_alpha_stepsize = 5,
    boundary_loss_alpha_warmup_epochs = 500,
    boundary_loss_alpha_max = 0.25,
    alpha_stepwise_warmup_scaled = True,

    oversample_foreground_percent = 1.0,
    probabilistic_oversampling = False,
    sample_class_probabilities = None,

    cascaded_mask_dilation = 0
):
    """Return a copy of ``base_config`` with training and architecture settings overridden.

    The training/sampling arguments (``batch_size``, ``patch_size``, the
    ``boundary_loss_alpha_*`` / ``alpha_stepwise_warmup_scaled`` settings, the
    oversampling settings and ``cascaded_mask_dilation``) are written as top-level
    keys of the copy, and ``is_cascaded_mask`` is forced to ``True``. The
    architecture arguments, plus a derived ``n_stages``, are written into
    ``copy["architecture"]["arch_kwargs"]``.

    Args:
        base_config: configuration dict to start from; it is deep-copied and
            never modified. Must contain ``["architecture"]["arch_kwargs"]``.
        features_per_stage: channels per encoder stage; its length defines
            ``n_stages``.
        kernel_sizes, strides, n_blocks_per_stage, exp_ratio_per_stage: one entry
            per encoder stage (length ``n_stages``).
        n_blocks_per_stage_decoder, exp_ratio_per_stage_decoder: one entry per
            decoder stage (length ``n_stages - 1``). If ``None``, derived from the
            encoder list by dropping its last entry and reversing the order.
        All other arguments are copied into the config unchanged.

    Returns:
        A new configuration dict. Every value written into it is a deep copy, so
        the result never aliases the (mutable) default arguments or the caller's
        lists; mutating one returned config cannot affect later calls.

    Raises:
        AssertionError: if a per-stage list has the wrong length.
        KeyError: if ``base_config`` lacks ``["architecture"]["arch_kwargs"]``.
    """
    # At this point the local namespace holds exactly the parameters; take an
    # independent snapshot so later assignments never depend on locals() semantics.
    args_dict = dict(locals())

    n_stages = len(features_per_stage)
    for name in ("kernel_sizes", "strides", "n_blocks_per_stage", "exp_ratio_per_stage"):
        assert len(args_dict[name]) == n_stages, (
            f"{name} has {len(args_dict[name])} entries, expected {n_stages}"
        )

    # Decoder settings default to the encoder settings without the last stage,
    # reversed (the decoder walks the stages in the opposite order).
    if args_dict["n_blocks_per_stage_decoder"] is None:
        args_dict["n_blocks_per_stage_decoder"] = n_blocks_per_stage[:-1][::-1]
    if args_dict["exp_ratio_per_stage_decoder"] is None:
        args_dict["exp_ratio_per_stage_decoder"] = exp_ratio_per_stage[:-1][::-1]
    for name in ("n_blocks_per_stage_decoder", "exp_ratio_per_stage_decoder"):
        assert len(args_dict[name]) == n_stages - 1, (
            f"{name} has {len(args_dict[name])} entries, expected {n_stages - 1}"
        )
    args_dict["n_stages"] = n_stages

    config_params = (
        "batch_size", "patch_size",
        "boundary_loss_alpha_stepsize", "boundary_loss_alpha_warmup_epochs",
        "boundary_loss_alpha_max", "alpha_stepwise_warmup_scaled",
        "oversample_foreground_percent", "probabilistic_oversampling", "sample_class_probabilities",
        "cascaded_mask_dilation"
    )
    arch_params = (
        "n_stages", "features_per_stage",
        "stem_kernel_size", "stem_channels", "stem_dilation",
        "kernel_sizes", "strides", "dilation_per_stage",
        "n_blocks_per_stage", "exp_ratio_per_stage",
        "n_blocks_per_stage_decoder", "exp_ratio_per_stage_decoder",
        "norm_type", "enable_affine_transform",
        "decoder_cat_skip"
    )

    config_copy = copy.deepcopy(base_config)
    config_copy["is_cascaded_mask"] = True
    # Deep-copy each value: the defaults above are shared mutable lists, and
    # storing them directly would alias them across every returned config.
    for config_param in config_params:
        config_copy[config_param] = copy.deepcopy(args_dict[config_param])

    arch = config_copy["architecture"]["arch_kwargs"]
    for arch_param in arch_params:
        arch[arch_param] = copy.deepcopy(args_dict[arch_param])
    return config_copy

def set_cascade_relationships(config, next_stages=(), prev_stage="3d_fullres"):
    """Record ``prev_stage`` as the ``previous_stage`` of each configuration in ``next_stages``.

    Mutates ``config`` in place: for every name in ``next_stages`` it sets
    ``config["configurations"][name]["previous_stage"] = prev_stage``. The entry
    for ``prev_stage`` itself is neither read nor modified, so it does not need
    to exist in ``config["configurations"]``.

    Args:
        config: plans dict containing a ``"configurations"`` mapping.
        next_stages: iterable of configuration names that follow ``prev_stage``.
            Defaults to an empty tuple (immutable, unlike a ``[]`` default).
        prev_stage: name stored as each next stage's ``previous_stage``.

    Raises:
        KeyError: if a name in ``next_stages`` is missing from
            ``config["configurations"]``.
    """
    for next_stage in next_stages:
        config["configurations"][next_stage]["previous_stage"] = prev_stage

In [10]:
# Every cascade-mask configuration starts from the "3d_M" configuration of the
# base plans and shares the same architecture and sampling settings; they differ
# only in `cascaded_mask_dilation`, which is also encoded in the config name.
base_model_config = base_config["configurations"]["3d_M"]

# Name stored as `previous_stage` of every generated cascade configuration.
CASCADE_PREVIOUS_STAGE = "ground_truth_binary"
# One configuration is generated per dilation value.
CASCADE_MASK_DILATIONS = (1, 5, 10)

# Settings shared by all cascade-mask configurations.
cascade_model_kwargs = dict(
    patch_size=(16, 128, 128),
    features_per_stage=(64, 128, 256, 320, 320, 320),
    stem_kernel_size=(1, 3, 3),
    stem_channels=None,
    stem_dilation=1,
    kernel_sizes=[
        [1, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
        [3, 3, 3],
    ],
    strides=[
        [1, 1, 1],
        [1, 2, 2],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
        [1, 2, 2],
    ],
    n_blocks_per_stage=[4, 5, 6, 6, 6, 6],
    exp_ratio_per_stage=[3, 4, 4, 4, 4, 4],
    n_blocks_per_stage_decoder=[1, 1, 1, 1, 1],
    exp_ratio_per_stage_decoder=[4, 4, 4, 4, 3],
    oversample_foreground_percent=1.0,
    probabilistic_oversampling=True,
    # NOTE(review): json.dump serializes these int keys as strings ("1", "2", "3");
    # confirm the consumer of the plans file expects that.
    sample_class_probabilities={1: 0.5, 2: 0.25, 3: 0.25},
)

new_config = copy.deepcopy(base_config)
new_config["plans_name"] = os.path.basename(os.path.splitext(new_config_fpath)[0])
# Drop the base configurations so only the cascade configurations are emitted.
new_config["configurations"] = {}
cascade_config_names = []
for dilation in CASCADE_MASK_DILATIONS:
    # Deriving the name from the dilation keeps name and setting in sync.
    config_name = f"cascade_mask_dil{dilation}_16_128_GT"
    # Deep-copy the shared kwargs so configurations never alias each other's lists.
    new_config["configurations"][config_name] = set_model_params(
        base_model_config,
        cascaded_mask_dilation=dilation,
        **copy.deepcopy(cascade_model_kwargs),
    )
    cascade_config_names.append(config_name)

assert len(new_config["configurations"]) == len(CASCADE_MASK_DILATIONS), "duplicate dilation values"

set_cascade_relationships(new_config, cascade_config_names, CASCADE_PREVIOUS_STAGE)

In [11]:
# Serialize the assembled plans to `new_config_fpath` as 2-space-indented JSON.
with open(new_config_fpath, mode="w") as plans_out:
    json.dump(new_config, plans_out, indent=2)