In [8]:
import json
import os
import copy

In [3]:
# Root of the nnU-Net preprocessed data. nnU-Net locates plans files through the
# `nnUNet_preprocessed` environment variable, so honor it when set; otherwise fall
# back to the original hard-coded location (resolved paths are unchanged then).
preprocessed_root = os.environ.get(
    "nnUNet_preprocessed", "/home/bryan/data/mbas_nnUNet_preprocessed"
)
plans_dir = os.path.join(preprocessed_root, "Dataset101_MBAS")

# Existing plans file that serves as the template for the new configurations.
base_config_fpath = os.path.join(plans_dir, "MedNeXtPlans.json")

# Output plans file; written next to the template under a separate, date-stamped
# name so the template itself is never overwritten.
new_config_fpath = os.path.join(plans_dir, "MedNeXtPlans_2024_07_17.json")

In [4]:
# Load the existing plans file; every new configuration is derived from it.
with open(base_config_fpath) as plans_file:
    base_config = json.load(plans_file)

In [5]:
# The "3d_fullres" entry is the template for every variant below. Note this is a
# reference into base_config, not a copy.
plans_configurations = base_config["configurations"]
base_model_config = plans_configurations["3d_fullres"]

In [28]:
def set_model_params(
    base_config,
    batch_size=2,
    patch_size=(16, 256, 256),
    features_per_stage=(32, 64, 128, 256, 320, 320, 320),
    stem_kernel_size=1,
    kernel_sizes=[
        (1, 3, 3),
        (1, 3, 3),
        (3, 3, 3),
        (3, 3, 3),
        (3, 3, 3),
        (3, 3, 3),
        (3, 3, 3),
    ],
    strides=[
        (1, 1, 1),
        (1, 2, 2),
        (1, 2, 2),
        (2, 2, 2),
        (2, 2, 2),
        (2, 2, 2),
        (2, 2, 2),
    ],
    n_blocks_per_stage=[3, 4, 6, 6, 6, 6, 6],
    exp_ratio_per_stage=[2, 3, 4, 4, 4, 4, 4],
    n_blocks_per_stage_decoder=None,
    exp_ratio_per_stage_decoder=None,
    norm_type="group",
    enable_affine_transform=False,
    decode_stem_kernel_size=3,
):
    """Return a copy of a plans configuration with its architecture kwargs overridden.

    ``base_config`` is deep-copied and never modified. On the copy, ``batch_size``
    and ``patch_size`` are set at the top level and every architecture argument is
    written into ``config["architecture"]["arch_kwargs"]``; all other keys of
    ``base_config`` are carried over unchanged.

    Every sequence argument is deep-copied before it is stored. Without that, the
    list defaults of this function (and any list a caller passes in) would be
    shared with the returned config, so mutating one result would silently change
    later calls.

    Args:
        base_config: One plans configuration (e.g. the "3d_fullres" entry); must
            contain ``["architecture"]["arch_kwargs"]``.
        batch_size: Stored as ``config["batch_size"]``.
        patch_size: Stored as ``config["patch_size"]``.
        features_per_stage: Per-stage feature counts; its length defines the
            number of stages (stored as ``n_stages``).
        stem_kernel_size: Stored as-is in arch_kwargs.
        kernel_sizes, strides, n_blocks_per_stage, exp_ratio_per_stage: One entry
            per stage.
        n_blocks_per_stage_decoder, exp_ratio_per_stage_decoder: One entry per
            stage; when None, default to the corresponding encoder list reversed.
        norm_type, enable_affine_transform, decode_stem_kernel_size: Stored as-is
            in arch_kwargs.

    Returns:
        The new configuration dict.

    Raises:
        AssertionError: If any per-stage sequence's length differs from
            ``len(features_per_stage)``.
    """
    config_copy = copy.deepcopy(base_config)
    config_copy["batch_size"] = batch_size
    config_copy["patch_size"] = copy.deepcopy(patch_size)
    arch = config_copy["architecture"]["arch_kwargs"]

    n_stages = len(features_per_stage)

    # Decoder settings mirror the encoder (deepest stage first) unless given.
    if n_blocks_per_stage_decoder is None:
        n_blocks_per_stage_decoder = n_blocks_per_stage[::-1]
    if exp_ratio_per_stage_decoder is None:
        exp_ratio_per_stage_decoder = exp_ratio_per_stage[::-1]

    # Validated in the same order as before, so the first reported mismatch is unchanged.
    per_stage = {
        "kernel_sizes": kernel_sizes,
        "strides": strides,
        "n_blocks_per_stage": n_blocks_per_stage,
        "exp_ratio_per_stage": exp_ratio_per_stage,
        "n_blocks_per_stage_decoder": n_blocks_per_stage_decoder,
        "exp_ratio_per_stage_decoder": exp_ratio_per_stage_decoder,
    }
    for param_name, values in per_stage.items():
        assert len(values) == n_stages, (
            f"{param_name} has {len(values)} entries, expected {n_stages} "
            f"(one per entry of features_per_stage)"
        )

    # Insertion order matches the original assignments so the dumped JSON is unchanged.
    arch.update({
        "n_stages": n_stages,
        "features_per_stage": copy.deepcopy(features_per_stage),
        "stem_kernel_size": stem_kernel_size,
        "kernel_sizes": copy.deepcopy(kernel_sizes),
        "strides": copy.deepcopy(strides),
        "n_blocks_per_stage": copy.deepcopy(n_blocks_per_stage),
        "exp_ratio_per_stage": copy.deepcopy(exp_ratio_per_stage),
        "n_blocks_per_stage_decoder": copy.deepcopy(n_blocks_per_stage_decoder),
        "exp_ratio_per_stage_decoder": copy.deepcopy(exp_ratio_per_stage_decoder),
        "norm_type": norm_type,
        "enable_affine_transform": enable_affine_transform,
        "decode_stem_kernel_size": decode_stem_kernel_size,
    })
    return config_copy

In [29]:
# Baseline variant: every architecture setting comes from set_model_params' defaults.
# NOTE(review): this result is not used by the cells that assemble and write
# new_config; confirm whether it was meant to replace "3d_fullres" there.
baseline_model_config = set_model_params(base_config=base_model_config)

In [30]:
# Variant "3d_01": four stages of widths 96/192/384/768, first-stage stride (4, 4, 4)
# with stem kernel 4, 5x5 kernels in the last two dimensions, nine blocks in the
# third stage, and a single block with expansion ratio 1 per decoder stage.
model_config_01 = set_model_params(
    base_model_config,
    decode_stem_kernel_size=3,
    exp_ratio_per_stage_decoder=[1] * 4,
    n_blocks_per_stage_decoder=[1] * 4,
    exp_ratio_per_stage=[3, 4, 4, 3],
    n_blocks_per_stage=[3, 3, 9, 3],
    strides=[(4, 4, 4), (1, 2, 2), (2, 2, 2), (1, 2, 2)],
    kernel_sizes=[(1, 5, 5), (3, 5, 5), (3, 5, 5), (1, 5, 5)],
    stem_kernel_size=4,
    features_per_stage=(96, 192, 384, 768),
    patch_size=(32, 256, 256),
)

In [31]:
def _four_stage_kwargs(in_plane_kernel):
    """Build the set_model_params overrides shared by variants "3d_02" and "3d_03".

    The two variants differ only in the extent of the last two kernel dimensions
    (presumably the in-plane axes), so that is the single parameter. Everything
    else is fixed: four stages of widths 64/128/256/256, first-stage stride
    (4, 4, 4) with stem kernel 4, nine blocks in the third stage, and a single
    block with expansion ratio 1 per decoder stage.

    A fresh dict with fresh lists is returned on every call, so the two configs
    never share mutable objects.

    Args:
        in_plane_kernel: Size used for the last two dimensions of every kernel.

    Returns:
        A dict of keyword arguments for set_model_params.
    """
    k = in_plane_kernel
    return dict(
        patch_size=(32, 256, 256),
        features_per_stage=(64, 128, 256, 256),
        stem_kernel_size=4,
        kernel_sizes=[(1, k, k), (3, k, k), (3, k, k), (1, k, k)],
        strides=[(4, 4, 4), (1, 2, 2), (2, 2, 2), (1, 2, 2)],
        n_blocks_per_stage=[3, 3, 9, 3],
        exp_ratio_per_stage=[3, 4, 4, 3],
        n_blocks_per_stage_decoder=[1, 1, 1, 1],
        exp_ratio_per_stage_decoder=[1, 1, 1, 1],
        decode_stem_kernel_size=3,
    )


# "3d_02": same layout as "3d_01" but with narrower stages (64/128/256/256), 5x5 kernels.
model_config_02 = set_model_params(base_model_config, **_four_stage_kwargs(5))
# "3d_03": identical to "3d_02" except for 3x3 kernels in the last two dimensions.
model_config_03 = set_model_params(base_model_config, **_four_stage_kwargs(3))

In [33]:
# Start from a deep copy so the loaded base_config stays untouched, then register
# the new variants next to the existing entries. Keys are inserted in the same
# order as before, so the written JSON is unchanged.
new_config = copy.deepcopy(base_config)
new_config["configurations"].update({
    # base_model_config is a reference into base_config. Deep-copy it so that
    # new_config does not alias the source plans dict (the content is unchanged).
    # NOTE(review): baseline_model_config (set_model_params defaults) is never
    # written out; if "3d_fullres" was meant to be replaced by it, use it here.
    "3d_fullres": copy.deepcopy(base_model_config),
    "3d_01": model_config_01,
    "3d_02": model_config_02,
    "3d_03": model_config_03,
})

In [34]:
# Serialize the assembled plans to the output file (tuples are emitted as JSON arrays).
with open(new_config_fpath, mode="w") as plans_out:
    plans_out.write(json.dumps(new_config, indent=2))