In [9]:
import json
import os
import copy

In [10]:
# Both plan files live in the same nnU-Net preprocessed dataset folder; keep that
# folder in one place so the input and output plans cannot point at different datasets.
PREPROCESSED_DATASET_DIR = os.path.join("/home/bryan/data/mbas_nnUNet_preprocessed", "Dataset101_MBAS")

# Input: the existing MedNeXt plans file.
base_config_fpath = os.path.join(PREPROCESSED_DATASET_DIR, "MedNeXtPlans.json")
# Output: the base plans plus the experimental configurations built below.
new_config_fpath = os.path.join(PREPROCESSED_DATASET_DIR, "MedNeXtPlans_2024_07_21.json")

In [11]:
# JSON text is UTF-8 (RFC 8259); state the encoding explicitly rather than relying
# on the platform's locale-dependent default for open().
with open(base_config_fpath, "r", encoding="utf-8") as f:
    base_config = json.load(f)

In [12]:
# Every experimental variant below starts from the plans' "3d_fullres" configuration.
configurations = base_config["configurations"]
base_model_config = configurations["3d_fullres"]

In [13]:
def _to_json_list(value):
    """Return a fresh, recursively list-ified copy of a list/tuple; other values pass through.

    Stored config values then never alias caller-owned (or default) sequences, and
    tuples become lists exactly as ``json.dump`` would write them.
    """
    if isinstance(value, (list, tuple)):
        return [_to_json_list(item) for item in value]
    return value


def set_model_params(
    base_config,
    batch_size = 2,
    patch_size = (16, 256, 256),
    features_per_stage = (32, 64, 128, 256, 320, 320, 320),
    stem_kernel_size = 1,
    kernel_sizes = (
        (1,3,3),
        (1,3,3),
        (3,3,3),
        (3,3,3),
        (3,3,3),
        (3,3,3),
        (3,3,3),
    ),
    strides = (
        (1,1,1),
        (1,2,2),
        (1,2,2),
        (2,2,2),
        (2,2,2),
        (2,2,2),
        (2,2,2),
    ),
    n_blocks_per_stage = (3,4,6,6,6,6,6),
    exp_ratio_per_stage = (2,3,4,4,4,4,4),
    n_blocks_per_stage_decoder = None,
    exp_ratio_per_stage_decoder = None,
    norm_type = "group",
    enable_affine_transform = False,
    decode_stem_kernel_size=3,
    override_down_kernel_size = True,
    down_kernel_size = 1,
):
    """Return a copy of a configuration dict with MedNeXt architecture settings applied.

    ``base_config`` is deep-copied and never mutated. On the copy, top-level
    ``batch_size`` and ``patch_size`` are set. Every architecture argument below is
    written into ``copy["architecture"]["arch_kwargs"]``; other keys already present
    there are kept.

    Args:
        base_config: configuration dict; must contain ``["architecture"]["arch_kwargs"]``.
        batch_size, patch_size: written to the top level of the copy.
        features_per_stage: value per encoder stage; its length defines ``n_stages``.
        kernel_sizes, strides, n_blocks_per_stage, exp_ratio_per_stage: one entry
            per stage each.
        n_blocks_per_stage_decoder, exp_ratio_per_stage_decoder: one entry per
            stage each. When ``None``, derived from the matching encoder list:
            its first ``n_stages - 1`` values reversed, followed by the first-stage
            value again (e.g. ``[3,4,6,6,6,6,6]`` -> ``[6,6,6,6,4,3,3]``).
        stem_kernel_size, norm_type, enable_affine_transform, decode_stem_kernel_size,
        override_down_kernel_size, down_kernel_size: copied verbatim into ``arch_kwargs``.

    Sequence arguments may be lists or tuples; they are stored as fresh (nested) lists.

    Returns:
        The modified deep copy of ``base_config``.

    Raises:
        AssertionError: if ``features_per_stage`` is empty or any per-stage list
            does not have ``n_stages`` entries.
        KeyError: if ``base_config`` lacks ``["architecture"]["arch_kwargs"]``.
    """
    config_copy = copy.deepcopy(base_config)
    config_copy["batch_size"] = batch_size
    config_copy["patch_size"] = _to_json_list(patch_size)
    arch = config_copy["architecture"]["arch_kwargs"]

    # Fresh list copies: the returned config must not share objects with the caller
    # or with this function's defaults, and should equal its own JSON round trip.
    features_per_stage = _to_json_list(features_per_stage)
    kernel_sizes = _to_json_list(kernel_sizes)
    strides = _to_json_list(strides)
    n_blocks_per_stage = _to_json_list(n_blocks_per_stage)
    exp_ratio_per_stage = _to_json_list(exp_ratio_per_stage)

    n_stages = len(features_per_stage)
    assert n_stages > 0, "features_per_stage must not be empty"

    def _check_len(name, values):
        assert len(values) == n_stages, (
            f"{name} has {len(values)} entries, expected {n_stages} (one per stage)"
        )

    _check_len("kernel_sizes", kernel_sizes)
    _check_len("strides", strides)
    _check_len("n_blocks_per_stage", n_blocks_per_stage)
    _check_len("exp_ratio_per_stage", exp_ratio_per_stage)

    # Decoder defaults mirror the encoder (see docstring). Slicing/concatenation is
    # safe even for tuple inputs because the encoder lists were list-ified above.
    if n_blocks_per_stage_decoder is None:
        n_blocks_per_stage_decoder = n_blocks_per_stage[:-1][::-1] + [n_blocks_per_stage[0]]
    else:
        n_blocks_per_stage_decoder = _to_json_list(n_blocks_per_stage_decoder)
    _check_len("n_blocks_per_stage_decoder", n_blocks_per_stage_decoder)

    if exp_ratio_per_stage_decoder is None:
        exp_ratio_per_stage_decoder = exp_ratio_per_stage[:-1][::-1] + [exp_ratio_per_stage[0]]
    else:
        exp_ratio_per_stage_decoder = _to_json_list(exp_ratio_per_stage_decoder)
    _check_len("exp_ratio_per_stage_decoder", exp_ratio_per_stage_decoder)

    arch.update(
        n_stages=n_stages,
        features_per_stage=features_per_stage,
        stem_kernel_size=stem_kernel_size,
        kernel_sizes=kernel_sizes,
        strides=strides,
        n_blocks_per_stage=n_blocks_per_stage,
        exp_ratio_per_stage=exp_ratio_per_stage,
        n_blocks_per_stage_decoder=n_blocks_per_stage_decoder,
        exp_ratio_per_stage_decoder=exp_ratio_per_stage_decoder,
        norm_type=norm_type,
        enable_affine_transform=enable_affine_transform,
        decode_stem_kernel_size=decode_stem_kernel_size,
        override_down_kernel_size=override_down_kernel_size,
        down_kernel_size=down_kernel_size,
    )
    return config_copy

In [34]:
# Every variant keeps the base plan's down-sampling kernel size and changes
# exactly one other setting relative to the set_model_params defaults.
shared_kwargs = {"override_down_kernel_size": False}

# Encoder widths capped at 128 from the third stage onward.
slim_128 = set_model_params(
    base_model_config,
    features_per_stage=(32, 64, 128, 128, 128, 128, 128),
    **shared_kwargs,
)

# A single block in each of the 7 decoder stages.
decoder_1_block = set_model_params(
    base_model_config,
    n_blocks_per_stage_decoder=[1] * 7,
    **shared_kwargs,
)

# Expansion ratio 1 in each of the 7 decoder stages.
decoder_1_exp_ratio = set_model_params(
    base_model_config,
    exp_ratio_per_stage_decoder=[1] * 7,
    **shared_kwargs,
)

# Like slim_128, but the first stage is 64 wide as well.
even_128 = set_model_params(
    base_model_config,
    features_per_stage=(64, 64, 128, 128, 128, 128, 128),
    **shared_kwargs,
)

In [35]:
# Derive the plans name from the output file's stem instead of repeating the
# literal, so renaming the output file keeps the two in sync.
new_plans_name = os.path.splitext(os.path.basename(new_config_fpath))[0]

# Start from an untouched copy of the base plans and register the new variants
# alongside the existing configurations.
new_config = copy.deepcopy(base_config)
new_config["plans_name"] = new_plans_name
new_config["configurations"].update(
    slim_128=slim_128,
    decoder_1_block=decoder_1_block,
    decoder_1_exp_ratio=decoder_1_exp_ratio,
    even_128=even_128,
)

In [36]:
# Write with an explicit UTF-8 encoding so the output file does not depend on the
# machine's locale default for open().
with open(new_config_fpath, "w", encoding="utf-8") as f:
    json.dump(new_config, f, indent=2)