In [1]:
# Experiment runner content
from argparse import Namespace
import os
import warnings
import json

from utils.model import inferer_registry
from experiments import run_experiments


def get_img_gts_jhu(dataset_dir):
    """List (image path, label path) pairs for the AbdomenAtlas/JHU dataset.

    Images are read from ``<dataset_dir>/imagesTr`` and labels from
    ``<dataset_dir>/labelsTr``. An image file ``<case>_0000.nii.gz`` is paired with
    the label file ``<case>.nii.gz``; images without an existing label are skipped.

    Args:
        dataset_dir: Root directory of the dataset.

    Returns:
        list[tuple[str, str]]: ``(image_path, label_path)`` pairs, in ``os.listdir`` order.
    """
    images_dir = os.path.join(dataset_dir, 'imagesTr')
    labels_dir = os.path.join(dataset_dir, 'labelsTr')
    imgs_gts = []
    for img_name in os.listdir(images_dir):
        # removesuffix drops exactly the '_0000.nii.gz' suffix. str.rstrip would strip
        # any trailing characters from the set "_0.nigz", mangling case ids that end
        # in 0 (e.g. 'case_0010_0000.nii.gz' -> 'case_001') and dropping those cases.
        label_path = os.path.join(labels_dir, img_name.removesuffix('_0000.nii.gz') + '.nii.gz')
        if os.path.exists(label_path):
            imgs_gts.append((os.path.join(images_dir, img_name), label_path))
    return imgs_gts

def get_imgs_gts_amos(dataset_dir):
    """List (image path, label path) pairs for the AMOS test split.

    Images live in ``<dataset_dir>/imagesTs`` and labels in ``<dataset_dir>/labelsTs``;
    an image is paired with the label file of the same name and skipped when no
    such label exists.

    Args:
        dataset_dir: Root directory of the dataset.

    Returns:
        list[tuple[str, str]]: ``(image_path, label_path)`` pairs, in ``os.listdir`` order.
    """
    images_dir, labels_dir = (os.path.join(dataset_dir, sub) for sub in ('imagesTs', 'labelsTs'))
    pairs = []
    for entry in os.listdir(images_dir):
        candidate_label = os.path.join(labels_dir, os.path.basename(entry))
        if not os.path.exists(candidate_label):
            continue
        pairs.append((os.path.join(images_dir, entry), candidate_label))
    return pairs

def get_imgs_gts_segrap(dataset_dir):
    """List (image path, label path) pairs for the SegRap subset.

    An image ``imagesTr/<case>_0000.nii.gz`` is paired with ``labelsTr/<case>.nii.gz``;
    images whose label file does not exist are left out.

    Args:
        dataset_dir: Root directory of the dataset.

    Returns:
        list[tuple[str, str]]: ``(image_path, label_path)`` pairs, in ``os.listdir`` order.
    """
    images_dir = os.path.join(dataset_dir, 'imagesTr')
    labels_dir = os.path.join(dataset_dir, 'labelsTr')

    def label_for(image_name):
        # '<case>_0000.nii.gz' -> '<labels_dir>/<case>.nii.gz'
        return os.path.join(labels_dir, image_name.removesuffix('_0000.nii.gz') + '.nii.gz')

    candidates = ((os.path.join(images_dir, name), label_for(name)) for name in os.listdir(images_dir))
    return [pair for pair in candidates if os.path.exists(pair[1])]

# Root folders for model weights and datasets. They default to the original
# workstation layout; set UNIVERSAL_MODELS_DIR / UNIVERSAL_DATASETS_DIR to run this
# notebook on another machine without editing code.
_MODELS_DIR = os.environ.get('UNIVERSAL_MODELS_DIR', '/home/t722s/Desktop/UniversalModels/TrainedModels')
_DATASETS_DIR = os.environ.get('UNIVERSAL_DATASETS_DIR', '/home/t722s/Desktop/Datasets')

# Model name -> checkpoint file passed to inferer_registry[model_name].
checkpoint_registry = {
    'sam': os.path.join(_MODELS_DIR, 'sam_vit_h_4b8939.pth'),
    'sammed2d': os.path.join(_MODELS_DIR, 'sam-med2d_b.pth'),
}

# Dataset name -> dataset root directory and the function listing its
# (image path, label path) pairs. The trailing '' keeps the trailing slash of the
# original paths. get_imgs_gts_amos has no entry here yet.
dataset_registry = {
    'abdomenAtlas': {'dir': os.path.join(_DATASETS_DIR, 'Dataset350_AbdomenAtlasJHU_2img', ''), 'dataset_func': get_img_gts_jhu},
    'segrap': {'dir': os.path.join(_DATASETS_DIR, 'segrapSub', ''), 'dataset_func': get_imgs_gts_segrap},
}

if __name__ == '__main__':
    # ---- Setup ---------------------------------------------------------------
    # Debugging toggle: uncomment to turn every warning into an exception.
    # warnings.filterwarnings('error')

    dataset_name = 'abdomenAtlas'   # key into dataset_registry
    model_name = 'sammed2d'         # key into checkpoint_registry and inferer_registry
    results_dir = os.environ.get('EXPERIMENT_RESULTS_DIR', '/home/t722s/Desktop/ExperimentResults')

    # Parameters read by run_experiments: prompt counts per strategy, plus the
    # perf/dof bounds it forwards to the interactive refinement loop.
    exp_params = Namespace(
        n_click_random_points=5,
        n_slice_point_interpolation=5,
        n_slice_box_interpolation=5,
        n_seed_points_point_propagation=5,
        n_points_propagation=5,
        dof_bound=60,
        perf_bound=0.85,
    )
    device = 'cuda'
    seed = 11121

    prompt_types = ['interactive']  # any subset of ['points', 'boxes', 'interactive']

    # Optional subset of labels to evaluate (label name -> integer value in the
    # label maps). Set to None to use every label listed in the dataset's dataset.json.
    # Full AbdomenAtlas label set, for reference:
    #     "background": 0, "aorta": 1, "gall_bladder": 2, "kidney_left": 3,
    #     "kidney_right": 4, "liver": 5, "pancreas": 6, "postcava": 7,
    #     "spleen": 8, "stomach": 9
    label_overwrite = {
        "kidney_left": 3,
    }

    # Debugging: optional list of experiment names to run, e.g. ['box_propagation'].
    experiment_overwrite = None

    # ---- Data ----------------------------------------------------------------
    # Get (img path, gt path) pairs
    results_path = os.path.join(results_dir, f'{model_name}_{dataset_name}.json')
    dataset_entry = dataset_registry[dataset_name]
    dataset_dir = dataset_entry['dir']
    imgs_gts = dataset_entry['dataset_func'](dataset_dir)

    # Label dictionary: the override if given, otherwise the 'labels' entry of the
    # dataset's dataset.json (only read when it is actually needed).
    if label_overwrite:
        label_dict = label_overwrite
    else:
        with open(os.path.join(dataset_dir, 'dataset.json'), 'r') as f:
            dataset_info = json.load(f)
        label_dict = dataset_info['labels']

    # ---- Model ---------------------------------------------------------------
    checkpoint_path = checkpoint_registry[model_name]
    inferer = inferer_registry[model_name](checkpoint_path, device)

    # NOTE(review): run_experiments is imported at the top of this cell but is not
    # called here; the call that produced the saved results is not visible — confirm.

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.11 (you have 1.4.8). Upgrade using: pip install --upgrade albumentations
  from .autonotebook import tqdm as notebook_tqdm


True
*******load /home/t722s/Desktop/UniversalModels/TrainedModels/sam-med2d_b.pth


In [2]:
# Experiments content
import os
import numpy as np
import json
import utils.analysis as anUt
import utils.prompt as prUt
from utils.interactivity import iterate_2d
from utils.image import read_reorient_nifti
from tqdm import tqdm
import warnings
def run_experiments(inferer, imgs_gts, results_path, label_dict,
                    exp_params, prompt_types,
                    seed, experiment_overwrite = None):
    """Run every selected prompting experiment on each (image, label) pair and save results as JSON.

    Args:
        inferer: Model wrapper. This function sets ``inferer.verbose`` and uses
            ``predict(img, prompt)``, ``pass_prev_prompts`` and ``clear_embeddings()``.
        imgs_gts: Iterable of ``(image_path, label_path)`` pairs readable by
            ``read_reorient_nifti``.
        results_path: Destination JSON file (overwritten).
        label_dict: Mapping label name -> integer value in the label maps. A
            ``"background"`` entry is skipped.
        exp_params: Namespace providing ``n_click_random_points``,
            ``n_slice_point_interpolation``, ``n_slice_box_interpolation``,
            ``n_seed_points_point_propagation``, ``n_points_propagation``,
            ``perf_bound`` and ``dof_bound``.
        prompt_types: Any subset of ``['points', 'boxes', 'interactive']`` selecting
            which experiment families are run.
        seed: Seed forwarded to the prompt generators and to ``iterate_2d``.
        experiment_overwrite: Optional list of experiment names (non-interactive or
            interactive); when given, only those experiments are run.

    The saved structure is ``results[experiment][label][image basename]``: the value
    of ``anUt.compute_dice`` for non-interactive experiments, and
    ``{'dof': ..., 'dice_scores': ...}`` from ``iterate_2d`` for interactive ones.
    """
    inferer.verbose = False # No need for progress bars per inference

    # Define experiments. Plain entries map organ_mask -> prompt; the *_propagation
    # entries take (organ_mask, slices_to_infer) and return (segmentation, prompt).
    # The lambdas read `img` lazily: it is bound in the image loop below before any
    # of them is called.
    experiments = {}

    if 'points' in prompt_types:
        experiments.update({
            'random_points': lambda organ_mask: prUt.get_pos_clicks2D_row_major(organ_mask, exp_params.n_click_random_points, seed=seed),
            'point_interpolation': lambda organ_mask: prUt.point_interpolation(prUt.get_fg_points_from_cc_centers(organ_mask, exp_params.n_slice_point_interpolation)),
            'point_propagation': lambda organ_mask, slices_to_infer: prUt.point_propagation(inferer, img, prUt.get_seed_point(organ_mask, exp_params.n_seed_points_point_propagation, seed),
                                                                slices_to_infer, seed, exp_params.n_points_propagation, verbose = False),
        })

    if 'boxes' in prompt_types:
        experiments.update({
            'bounding_boxes': lambda organ_mask: prUt.get_minimal_boxes_row_major(organ_mask),
            'bbox3d_sliced': lambda organ_mask: prUt.get_bbox3d_sliced(organ_mask),
            'box_interpolation': lambda organ_mask: prUt.box_interpolation(prUt.get_seed_boxes(organ_mask, exp_params.n_slice_box_interpolation)),
            'box_propagation': lambda organ_mask, slices_to_infer: prUt.box_propagation(inferer, img, prUt.get_seed_box(organ_mask), slices_to_infer, verbose = False)
        })

    # Interactive experiments: the initial prompt is produced the same way, then the
    # result is refined with iterate_2d.
    interactive_experiments = {}
    if 'interactive' in prompt_types:
        interactive_experiments.update({
            'point_interpolation_interactive': lambda organ_mask: prUt.point_interpolation(prUt.get_fg_points_from_cc_centers(organ_mask, exp_params.n_slice_point_interpolation)),
            'point_propagation_interactive': lambda organ_mask, slices_to_infer: prUt.point_propagation(inferer, img, prUt.get_seed_point(organ_mask, exp_params.n_seed_points_point_propagation, seed),
                                                                slices_to_infer, seed, exp_params.n_points_propagation, verbose = False),
        })

    # Debugging: restrict to the requested experiments. Each dict is filtered against
    # itself — interactive names are never keys of `experiments`.
    if experiment_overwrite:
        experiments = {ex: experiments[ex] for ex in experiment_overwrite if ex in experiments}
        interactive_experiments = {ex: interactive_experiments[ex] for ex in experiment_overwrite if ex in interactive_experiments}

    experiment_names = list(experiments.keys()) + list(interactive_experiments.keys())

    # Initialize results dictionary: results[experiment][label][image basename]
    results = {exp_name: {label: {} for label in label_dict if label != "background"} for exp_name in experiment_names}

    # Loop through all image and label pairs
    for img_path, gt_path in tqdm(imgs_gts, desc = 'looping through files\n'):
        base_name = os.path.basename(img_path)
        img, gt = read_reorient_nifti(img_path).astype(np.float32), read_reorient_nifti(gt_path).astype(int)

        # Loop through each organ label except the background
        for label_name, label_val in tqdm(label_dict.items(), desc = 'looping through organs\n', leave = False):
            if label_name == 'background':
                continue

            organ_mask = np.where(gt == label_val, 1, 0)
            if not np.any(organ_mask):  # Skip if no foreground for this label
                warnings.warn(f'{gt_path} missing segmentation for {label_name}')
                continue

            # Indices along axis 0 of the slices that contain any organ voxel
            slices_to_infer = np.where(np.any(organ_mask, axis=(1, 2)))[0]

            # Handle non-interactive experiments
            for exp_name, prompting_func in tqdm(experiments.items(), desc = 'looping through non_interactive experiments', leave = False):
                if exp_name in ['point_propagation', 'box_propagation']:
                    segmentation, prompt = prompting_func(organ_mask, slices_to_infer)
                else:
                    prompt = prompting_func(organ_mask)
                    segmentation = inferer.predict(img, prompt)
                dice_score = anUt.compute_dice(segmentation, organ_mask)
                results[exp_name][label_name][base_name] = dice_score

            # Now handle interactive experiments
            for exp_name, prompting_func in tqdm(interactive_experiments.items(), desc = 'looping through interactive experiments', leave = False):
                # Set the few things that differ depending on the seed method
                if exp_name in ['point_propagation_interactive', 'box_propagation_interactive']:
                    segmentation, prompt = prompting_func(organ_mask, slices_to_infer)
                    init_dof = 5
                else:
                    prompt = prompting_func(organ_mask)
                    segmentation = inferer.predict(img, prompt)
                    init_dof = 9

                dice_scores, dofs = iterate_2d(inferer, img, organ_mask, segmentation, prompt, inferer.pass_prev_prompts,
                                               scribble_length = 0.6, contour_distance = 3, disk_size_range= (0,3),
                                               init_dof = init_dof, perf_bound = exp_params.perf_bound, dof_bound = exp_params.dof_bound, seed = seed, verbose = False)

                results[exp_name][label_name][base_name] = {'dof': dofs, 'dice_scores': dice_scores}

            # Called once per evaluated organ (inside the label loop).
            inferer.clear_embeddings()

    # Save results
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=4)

    print(f"Results saved to {results_path}")


looping through files
:   0%|          | 0/2 [00:00<?, ?it/s]


looping through organs
[A
[A
[A

looping through interactive experiments: 100%|██████████| 2/2 [00:18<00:00,  9.48s/it]

looping through filess
:  50%|█████     | 1/2 [00:20<00:20, 20.67s/it]
looping through organs
[A
[A
[A
[A
looping through interactive experiments: 100%|██████████| 2/2 [00:10<00:00,  5.41s/it]

looping through filess
looping through files2 [00:32<00:00, 15.33s/it]
: 100%|██████████| 2/2 [00:32<00:00, 16.13s/it]

Results saved to /home/t722s/Desktop/ExperimentResults/sammed2d_abdomenAtlas.json



