In [None]:
import copy
import itertools
import os
import random
import subprocess

from torchtrainer.util.train_util import dict_to_argv

# All samples in the VessMAP dataset
names = [
    '5472', '2413', '3406', '7577', '4404', '12005', '10084', '3882', '15577', '15375', '8353', '17035',
    '13114', '4413', '7783', '11411', '6524', '6581', '13200', '9860', '525', '2643', '8990', '9284',
    '2050', '2071', '13128', '7865', '14440', '8196', '17880', '1643', '11558', '12943', '2546', '9452',
    '11828', '8493', '14225', '8256', '1816', '14121', '11161', '16707', '356', '12877', '6818', '10571',
    '6672', '17702', '15821', '8429', '18180', '13528', '16689', '12960', '5359', '6384', '7392', '6887',
    '8506', '1585', '4938', '458', '5801', '8686', '15160', '7413', '8065', '8284', '9593', '17584', '2849',
    '9710', '5740', '4739', '2958', '14787', '11098', '17630', '11111', '6656', '17852', '9000', '12455', '9523',
    '4909', '12618', '14778', '16295', '17425', '14690', '12749', '12335', '7083', '2287', '482', '7344', '18035',
    '16766'
]

# Root of the project tree. Set the MASTERS_DEGREE_ROOT environment variable to
# run this sweep on another machine without editing the notebook; the default
# reproduces the original hardcoded location.
PROJECT_ROOT = os.environ.get("MASTERS_DEGREE_ROOT", "/home/fonta42/Desktop/masters-degree")

# These parameters are common defaults for all experiments.
base_params = {
    # Logging parameters:
    "experiments_path": f"{PROJECT_ROOT}/experiments/torch-trainer",
    "run_name": "", # Will be dynamically generated per run
    "validate_every": 50,
    "copy_model_every": 0,
    "wandb_project": "uncategorized",

    # Dataset parameters:
    "dataset_path": f"{PROJECT_ROOT}/data/torch-trainer/VessMAP",
    "dataset_class": "vessmap_few",
    "resize_size": "256 256", # Default, overridden per model in the experiment cells
    "loss_function": "bce", # Default, overridden by parameter_variations

    # Model parameters:
    "model_class": "", # Set per model in the experiment cells

    # Training parameters:
    "num_epochs": 1000,
    "validation_metric": "Dice",
    "lr": 0.001, # Default, overridden by parameter_variations
    "lr_decay": 1.0, # Default, overridden by parameter_variations
    "bs_train": 2,
    "bs_valid": 2,
    "weight_decay": 0.0, # Default, overridden by parameter_variations
    "optimizer": "adam",
    "momentum": 0.9,
    "seed": 42,

    # Device and efficiency parameters:
    "device": "cuda:0",
    "num_workers": 5,
    "benchmark": "", # Empty string means 'use benchmark if available'
}


# Define the parameters to vary and their possible values.
# Every key listed here is swept, so every key must be reflected in each run_name.
parameter_variations = {
    "split_strategy": [20, 90],  # Number of samples to select for the training split strategy
    "val_img_indices": ["0 1 2 3", "0 2 4"],
    "loss_function": ["bce", "cross_entropy"],
    "lr": [0.001, 0.01],
    "lr_decay": [1.0, 0.9],
    "weight_decay": [0.0, 1e-4]
}

# Sanity checks on the sweep definition. The experiment cells sample names
# without replacement, so names must be unique and no split size may exceed
# the dataset size. Failing here is clearer than failing mid-sweep.
assert len(set(names)) == len(names), "Duplicate sample names in the VessMAP list"
assert max(parameter_variations["split_strategy"]) <= len(names), "Split size exceeds the number of samples"

In [None]:
# Grid expansion: one configuration per point of the Cartesian product of the sweep values.
def generate_parameter_combinations(base_params, variations_dict):
    """
    Expand a hyperparameter grid into a list of full experiment configurations.

    Every configuration starts as an independent deep copy of ``base_params``.
    Each key of ``variations_dict`` is then set to one of its candidate values.
    One configuration is produced per element of the Cartesian product of the
    candidate lists, in ``itertools.product`` order (the last key varies fastest).

    Args:
        base_params (dict): Default parameters shared by every configuration.
        variations_dict (dict): Maps a parameter name to the list of values to sweep.

    Returns:
        list: One dict per combination. The count is also printed.
    """
    sweep_keys = list(variations_dict.keys())
    grid = itertools.product(*(variations_dict[key] for key in sweep_keys))

    def _build(point):
        # Deep copy so configurations never share mutable defaults with each other.
        config = copy.deepcopy(base_params)
        config.update(zip(sweep_keys, point))
        return config

    configurations = [_build(point) for point in grid]
    print(f"Generated {len(configurations)} parameter combinations.")
    return configurations

# Expand the sweep once; each model cell below iterates over this same list.
all_params_list = generate_parameter_combinations(base_params, parameter_variations)
total_experiments = len(all_params_list)

# Keys passed to dict_to_argv as its exclusion list, presumably so they are not
# emitted as command-line flags for the training scripts.
# NOTE(review): no params dict built in this notebook has a "split_size" key;
# this entry looks stale — confirm before removing.
exclude_argv_common = [
    "dataset_path",
    "dataset_class",
    "model_class",
    "split_size",
]

In [None]:
# --- Run MedSAM Experiments ---
# Launches one training subprocess per grid configuration. A run that fails
# (non-zero exit status, or an exception while building/launching the command)
# is reported and recorded, and the loop continues with the next configuration.
print(f"\n--- Starting MedSAM Experiments ({total_experiments} runs) ---")
medsam_script = "./medsam_train_torchtrainer.py"
medsam_overrides = {
    "resize_size": "1024 1024",
    "model_class": "medsam",
    "experiment_name": "medsam_runs"
}

# Model specific exclusions (a copy, so per-model edits never mutate the shared list)
medsam_exclude_argv = list(exclude_argv_common)
medsam_failed_runs = []

for i, base_combo_params in enumerate(all_params_list):
    params = copy.deepcopy(base_combo_params)

    # Apply MedSAM specific overrides
    params.update(medsam_overrides)

    # Replace the integer split size with a comma-separated list of sample names.
    # The draw uses a local RNG seeded from (seed, split size). It is therefore
    # reproducible across re-runs and identical for every configuration and model
    # that shares a split size, which keeps comparisons between them fair.
    split_size = params['split_strategy']
    split_rng = random.Random(f"{params['seed']}-{split_size}")
    params['split_strategy'] = ','.join(split_rng.sample(names, split_size))

    # Every swept parameter must appear in run_name. Otherwise distinct
    # configurations (e.g. differing only in lr_decay or val_img_indices)
    # would get the same name.
    val_indices_str = params['val_img_indices'].replace(' ', '')
    params["run_name"] = (
        f"medsam_{split_size}_{params['resize_size'].replace(' ','x')}_{params['loss_function']}"
        f"_{params['num_epochs']}_{params['lr']}_{params['lr_decay']}_{params['bs_train']}"
        f"_{params['bs_valid']}_{params['weight_decay']}_val{val_indices_str}"
    )

    print(f"\nRunning MedSAM experiment {i+1}/{total_experiments}: {params['run_name']}")
    try:
        commandline = ' '.join(dict_to_argv(params, medsam_exclude_argv))
        print(f"Executing: python {medsam_script} {commandline}")

        # Run through the shell, as `!python ...` did, so word splitting of the
        # command line is unchanged. Unlike `!`, the exit status is kept: `!`
        # never raises when the script fails, so failed runs went unreported.
        completed = subprocess.run(f"python {medsam_script} {commandline}", shell=True)
        if completed.returncode != 0:
            raise RuntimeError(f"{medsam_script} exited with status {completed.returncode}")

        print(f"Finished attempt for experiment: {params['run_name']}") # Only reached on exit status 0

    except Exception as e:
        # Report the failure and continue with the next experiment
        medsam_failed_runs.append(params['run_name'])
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print(f"!!! ERROR during MedSAM experiment: {params['run_name']}")
        print(f"!!! Error type: {type(e).__name__}")
        print(f"!!! Error details: {e}")
        print("!!! Skipping this run and continuing with the next experiment.")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("--- Finished MedSAM Experiment Loop ---")
if medsam_failed_runs:
    print(f"{len(medsam_failed_runs)} MedSAM run(s) failed: {', '.join(medsam_failed_runs)}")

In [None]:
# --- Run U-Mamba Experiments ---
# Launches one training subprocess per grid configuration. A run that fails
# (non-zero exit status, or an exception while building/launching the command)
# is reported and recorded, and the loop continues with the next configuration.
print(f"\n--- Starting U-Mamba Experiments ({total_experiments} runs) ---")
umamba_script = "./umamba_train_torchtrainer.py"
umamba_overrides = {
    "resize_size": "256 256",
    "model_class": "umamba",
    "experiment_name": "umamba_runs"
}
# Model specific exclusions (a copy, so per-model edits never mutate the shared list)
umamba_exclude_argv = list(exclude_argv_common)
umamba_failed_runs = []

for i, base_combo_params in enumerate(all_params_list):
    params = copy.deepcopy(base_combo_params)
    params.update(umamba_overrides)

    # Replace the integer split size with a comma-separated list of sample names.
    # The draw uses a local RNG seeded from (seed, split size). It is therefore
    # reproducible across re-runs and identical for every configuration and model
    # that shares a split size, which keeps comparisons between them fair.
    split_size = params['split_strategy']
    split_rng = random.Random(f"{params['seed']}-{split_size}")
    params['split_strategy'] = ','.join(split_rng.sample(names, split_size))

    # Every swept parameter must appear in run_name. Otherwise configurations
    # that differ only in lr_decay would get the same name.
    val_indices_str = params['val_img_indices'].replace(' ', '')
    params["run_name"] = (
        f"umamba_{split_size}_{params['resize_size'].replace(' ','x')}_{params['loss_function']}"
        f"_{params['num_epochs']}_{params['lr']}_{params['lr_decay']}_{params['bs_train']}"
        f"_{params['bs_valid']}_{params['weight_decay']}_val{val_indices_str}"
    )

    print(f"\nRunning U-Mamba experiment {i+1}/{total_experiments}: {params['run_name']}")
    try:
        commandline = ' '.join(dict_to_argv(params, umamba_exclude_argv))
        print(f"Executing: python {umamba_script} {commandline}")

        # Run through the shell, as `!python ...` did, so word splitting of the
        # command line is unchanged. Unlike `!`, the exit status is kept: `!`
        # never raises when the script fails, so failed runs went unreported.
        completed = subprocess.run(f"python {umamba_script} {commandline}", shell=True)
        if completed.returncode != 0:
            raise RuntimeError(f"{umamba_script} exited with status {completed.returncode}")

        print(f"Finished attempt for experiment: {params['run_name']}")

    except Exception as e:
        # Report the failure and continue with the next experiment
        umamba_failed_runs.append(params['run_name'])
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print(f"!!! ERROR during U-Mamba experiment: {params['run_name']}")
        print(f"!!! Error type: {type(e).__name__}")
        print(f"!!! Error details: {e}")
        print("!!! Skipping this run and continuing with the next experiment.")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

print("--- Finished U-Mamba Experiment Loop ---")
if umamba_failed_runs:
    print(f"{len(umamba_failed_runs)} U-Mamba run(s) failed: {', '.join(umamba_failed_runs)}")

In [None]:
# --- Run U-Net Experiments ---
# Launches one training subprocess per grid configuration. A run that fails
# (non-zero exit status, or an exception while building/launching the command)
# is reported and recorded, and the loop continues with the next configuration.
print(f"\n--- Starting U-Net Experiments ({total_experiments} runs) ---")
unet_script = "unet_train_torchtrainer.py"
unet_overrides = {
    "resize_size": "256 256",
    "model_class": "unet_smp",
    "experiment_name": "unet_runs"
}
# Model specific exclusions (a copy, so per-model edits never mutate the shared list)
unet_exclude_argv = list(exclude_argv_common)
unet_failed_runs = []

for i, base_combo_params in enumerate(all_params_list):
    params = copy.deepcopy(base_combo_params)
    params.update(unet_overrides)

    # Replace the integer split size with a comma-separated list of sample names.
    # The draw uses a local RNG seeded from (seed, split size). It is therefore
    # reproducible across re-runs and identical for every configuration and model
    # that shares a split size, which keeps comparisons between them fair.
    split_size = params['split_strategy']
    split_rng = random.Random(f"{params['seed']}-{split_size}")
    params['split_strategy'] = ','.join(split_rng.sample(names, split_size))

    # Every swept parameter must appear in run_name. Otherwise configurations
    # that differ only in lr_decay would get the same name.
    val_indices_str = params['val_img_indices'].replace(' ', '')
    params["run_name"] = (
        f"unet_{split_size}_{params['resize_size'].replace(' ','x')}_{params['loss_function']}"
        f"_{params['num_epochs']}_{params['lr']}_{params['lr_decay']}_{params['bs_train']}"
        f"_{params['bs_valid']}_{params['weight_decay']}_val{val_indices_str}"
    )

    print(f"\nRunning U-Net experiment {i+1}/{total_experiments}: {params['run_name']}")
    try:
        commandline = ' '.join(dict_to_argv(params, unet_exclude_argv))
        print(f"Executing: python {unet_script} {commandline}")

        # Run through the shell, as `!python ...` did, so word splitting of the
        # command line is unchanged. Unlike `!`, the exit status is kept: `!`
        # never raises when the script fails, so failed runs went unreported.
        completed = subprocess.run(f"python {unet_script} {commandline}", shell=True)
        if completed.returncode != 0:
            raise RuntimeError(f"{unet_script} exited with status {completed.returncode}")

        print(f"Finished attempt for experiment: {params['run_name']}")

    except Exception as e:
        # Report the failure and continue with the next experiment
        unet_failed_runs.append(params['run_name'])
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print(f"!!! ERROR during U-Net experiment: {params['run_name']}")
        print(f"!!! Error type: {type(e).__name__}")
        print(f"!!! Error details: {e}")
        print("!!! Skipping this run and continuing with the next experiment.")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

print("--- Finished U-Net Experiment Loop ---")
if unet_failed_runs:
    print(f"{len(unet_failed_runs)} U-Net run(s) failed: {', '.join(unet_failed_runs)}")


print("\n--- All experiment loops finished ---")