In [8]:
import json
import os
import copy

In [9]:
# Existing plans file used as the template; this notebook only reads it.
base_config_fpath = "/home/bryan/data/mbas_nnUNet_preprocessed/Dataset101_MBAS/MedNeXtPlans.json"

# Destination of the derived plans file. Its basename (without ".json") is also
# used as the `plans_name` stored inside the new plans.
new_config_fpath = "/home/bryan/data/mbas_nnUNet_preprocessed/Dataset101_MBAS/MedNeXtPlans_2024_08_03.json"

In [10]:
# Load the template plans. The encoding is given explicitly so the read does not
# depend on the platform/locale default used by `open()`.
with open(base_config_fpath, "r", encoding="utf-8") as f:
    base_config = json.load(f)

In [11]:
def set_model_params(
    base_config,
    batch_size = 2,
    patch_size = (16, 256, 256),
    features_per_stage = (32, 64, 128, 256, 320, 320, 320),
    stem_kernel_size = 1,
    kernel_sizes = [
        (1,3,3),
        (1,3,3), 
        (3,3,3),
        (3,3,3),
        (3,3,3),
        (3,3,3),
        (3,3,3),
    ],
    strides = [
        (1,1,1),
        (1,2,2),
        (1,2,2),
        (2,2,2),
        (2,2,2),
        (2,2,2),
        (2,2,2),
    ],
    n_blocks_per_stage = [3,4,6,6,6,6,6],
    exp_ratio_per_stage = [2,3,4,4,4,4,4],
    n_blocks_per_stage_decoder = None,
    exp_ratio_per_stage_decoder = None,
    norm_type = "group",
    enable_affine_transform = False,
    decode_stem_kernel_size=3,
    override_down_kernel_size = True,
    down_kernel_size = 1,
    boundary_loss_alpha_stepsize = 5,
    boundary_loss_alpha_warmup_epochs = 250,
    boundary_loss_alpha_max = 0.75,
    oversample_foreground_percent = 1.0,
    probabilistic_oversampling = False,
    sample_class_probabilities = None,
    alpha_stepwise_warmup_scaled = True,
    
):
    """Return a deep copy of ``base_config`` with the given parameters applied.

    ``base_config`` itself is never modified. It must contain
    ``base_config["architecture"]["arch_kwargs"]`` (a dict).

    Top-level keys written on the copy (each from the same-named argument):
    ``batch_size``, ``patch_size``, ``boundary_loss_alpha_stepsize``,
    ``boundary_loss_alpha_warmup_epochs``, ``boundary_loss_alpha_max``,
    ``alpha_stepwise_warmup_scaled``, ``oversample_foreground_percent``,
    ``probabilistic_oversampling`` and ``sample_class_probabilities``.

    Keys written under ``["architecture"]["arch_kwargs"]``: ``n_stages``
    (``len(features_per_stage)``) plus one key per remaining argument, again
    using the argument name as the key.

    ``kernel_sizes``, ``strides``, ``n_blocks_per_stage`` and
    ``exp_ratio_per_stage`` must each have one entry per stage. When a
    ``*_decoder`` argument is ``None`` it is derived from its encoder
    counterpart: drop the last entry, reverse, then append the first entry
    (``[3, 4, 6]`` -> ``[4, 3, 3]``). An explicitly passed decoder sequence
    must also have one entry per stage.

    Container arguments are deep-copied before being stored, so the returned
    configuration never shares list/dict objects with this function's default
    arguments, with the caller's objects, or with other returned configurations.

    Returns:
        dict: the updated copy of ``base_config``.

    Raises:
        AssertionError: if a per-stage sequence does not have ``n_stages`` entries.
        KeyError: if ``base_config`` lacks ``["architecture"]["arch_kwargs"]``.
    """
    config_copy = copy.deepcopy(base_config)
    config_copy["batch_size"] = batch_size
    config_copy["patch_size"] = copy.deepcopy(patch_size)
    config_copy["boundary_loss_alpha_stepsize"] = boundary_loss_alpha_stepsize
    config_copy["boundary_loss_alpha_warmup_epochs"] = boundary_loss_alpha_warmup_epochs
    config_copy["boundary_loss_alpha_max"] = boundary_loss_alpha_max
    config_copy["alpha_stepwise_warmup_scaled"] = alpha_stepwise_warmup_scaled

    config_copy["oversample_foreground_percent"] = oversample_foreground_percent
    config_copy["probabilistic_oversampling"] = probabilistic_oversampling
    config_copy["sample_class_probabilities"] = copy.deepcopy(sample_class_probabilities)

    arch = config_copy["architecture"]["arch_kwargs"]

    n_stages = len(features_per_stage)
    assert len(kernel_sizes) == n_stages, (
        f"kernel_sizes has {len(kernel_sizes)} entries, expected {n_stages}"
    )
    assert len(strides) == n_stages, (
        f"strides has {len(strides)} entries, expected {n_stages}"
    )
    assert len(n_blocks_per_stage) == n_stages, (
        f"n_blocks_per_stage has {len(n_blocks_per_stage)} entries, expected {n_stages}"
    )
    assert len(exp_ratio_per_stage) == n_stages, (
        f"exp_ratio_per_stage has {len(exp_ratio_per_stage)} entries, expected {n_stages}"
    )

    # list(...) keeps the concatenation valid when the encoder sequence is a
    # tuple (tuple + list raises TypeError) and always yields a new list.
    if n_blocks_per_stage_decoder is None:
        n_blocks_per_stage_decoder = list(n_blocks_per_stage[:-1][::-1]) + [n_blocks_per_stage[0]]
    assert len(n_blocks_per_stage_decoder) == n_stages, (
        f"n_blocks_per_stage_decoder has {len(n_blocks_per_stage_decoder)} entries, expected {n_stages}"
    )
    if exp_ratio_per_stage_decoder is None:
        exp_ratio_per_stage_decoder = list(exp_ratio_per_stage[:-1][::-1]) + [exp_ratio_per_stage[0]]
    assert len(exp_ratio_per_stage_decoder) == n_stages, (
        f"exp_ratio_per_stage_decoder has {len(exp_ratio_per_stage_decoder)} entries, expected {n_stages}"
    )

    # Deep copies below prevent the returned config from aliasing the mutable
    # default arguments (which are shared across calls) or the caller's lists.
    arch["n_stages"] = n_stages
    arch["features_per_stage"] = copy.deepcopy(features_per_stage)
    arch["stem_kernel_size"] = stem_kernel_size
    arch["kernel_sizes"] = copy.deepcopy(kernel_sizes)
    arch["strides"] = copy.deepcopy(strides)
    arch["n_blocks_per_stage"] = copy.deepcopy(n_blocks_per_stage)
    arch["exp_ratio_per_stage"] = copy.deepcopy(exp_ratio_per_stage)
    arch["n_blocks_per_stage_decoder"] = copy.deepcopy(n_blocks_per_stage_decoder)
    arch["exp_ratio_per_stage_decoder"] = copy.deepcopy(exp_ratio_per_stage_decoder)
    arch["norm_type"] = norm_type
    arch["enable_affine_transform"] = enable_affine_transform
    arch["decode_stem_kernel_size"] = decode_stem_kernel_size
    arch["override_down_kernel_size"] = override_down_kernel_size
    arch["down_kernel_size"] = down_kernel_size
    return config_copy

def set_cascade_relationships(config, next_stages = [], prev_stage = "3d_fullres"):
    """Record ``prev_stage`` as the ``previous_stage`` of each listed configuration.

    ``config`` is modified in place and nothing is returned. Every name in
    ``next_stages`` must already exist under ``config["configurations"]``;
    a missing name raises ``KeyError``.

    Only the backward link is written: the ``next_stage`` entry of
    ``prev_stage`` is deliberately left as it is.
    """
    for stage_name in next_stages:
        stage_entry = config["configurations"][stage_name]
        stage_entry["previous_stage"] = prev_stage

In [12]:
# Template from which every new configuration is derived.
base_model_config = base_config["configurations"]["3d_fullres"]

# Start from a full copy of the base plans, name it after the output file
# (basename without extension) and drop the inherited configurations so that
# only the variants defined below end up in the new plans.
new_config = copy.deepcopy(base_config)
new_config["plans_name"] = os.path.basename(os.path.splitext(new_config_fpath)[0])
new_config["configurations"] = {}

# The variants share every setting except the three listed per row:
# (configuration name,
#  boundary_loss_alpha_warmup_epochs,
#  boundary_loss_alpha_max,
#  alpha_stepwise_warmup_scaled)
slim_128_variants = [
    ("slim_128_oversample_05_alpha05_warm250_max075", 250, 0.75, False),
    ("slim_128_oversample_05_alpha05_warm250_max050", 250, 0.50, False),
    ("slim_128_oversample_05_alpha05_warm250_max075_scaled", 250, 0.75, True),
    ("slim_128_oversample_05_alpha05_warm250_max050_scaled", 250, 0.50, True),
    ("slim_128_oversample_05_alpha05_warm500_max025_scaled", 500, 0.25, True),
]

for variant_name, variant_warmup_epochs, variant_alpha_max, variant_scaled in slim_128_variants:
    new_config["configurations"][variant_name] = set_model_params(
        base_model_config,
        override_down_kernel_size = False,
        features_per_stage = (32, 64, 128, 128, 128, 128, 128),
        oversample_foreground_percent = 1.0,
        probabilistic_oversampling = True,
        # A fresh dict literal per iteration, so no two configurations share
        # one dict object.
        # NOTE: JSON object keys are always strings, so these integer keys are
        # saved as "1", "2", "3" -- confirm the consumer converts them back.
        sample_class_probabilities = {1: 0.5, 2: 0.25, 3: 0.25},
        boundary_loss_alpha_stepsize = 5,
        boundary_loss_alpha_warmup_epochs = variant_warmup_epochs,
        boundary_loss_alpha_max = variant_alpha_max,
        alpha_stepwise_warmup_scaled = variant_scaled,
    )

# A duplicated name in the table would silently overwrite an earlier variant.
assert len(new_config["configurations"]) == len(slim_128_variants), "duplicate configuration name"

In [13]:
# Serialize before opening the destination: opening with "w" truncates the file
# immediately, so a serialization error raised afterwards would leave an empty
# or partial plans file behind.
new_config_json = json.dumps(new_config, indent=2)
with open(new_config_fpath, "w", encoding="utf-8") as f:
    f.write(new_config_json)