In [1]:
import os
import sys
from pathlib import Path

# Put the local project checkouts and shared libraries on the import path
# (appended in the same order as before) so the packages below resolve.
sys.path.extend([
    '/storage/vbutoi/projects',
    '/storage/vbutoi/libraries',
    '/storage/vbutoi/projects/ESE',
    '/storage/vbutoi/projects/UniverSeg',
])

import torch
import seaborn as sns

# Display preferences: wide tensor printing and the seaborn dark grid theme.
torch.set_printoptions(linewidth=200)
sns.set_style("darkgrid")

# DATAPATH is a ':'-joined list of dataset directories (a single entry here).
# It is set before ionpy is imported, matching the original statement order.
os.environ['DATAPATH'] = ':'.join([
    '/storage/vbutoi/datasets',
])

from ionpy.analysis import ResultsLoader

# `root` is the scratch directory that experiment paths in this notebook are
# built from; `rs` is the results loader used to read experiment configs.
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()


# For using code without restarting.
%load_ext autoreload
%autoreload 
# For using yaml configs.
%load_ext yamlmagic

In [2]:
# Training experiment group, given relative to the scratch `root` directory.
training_exp_name = "training/01_03_24_WMH_Ensemble"

# Read the configs stored under that experiment directory.
dfc = rs.load_configs(root / training_exp_name, properties=False)

  0%|          | 0/6 [00:00<?, ?it/s]

In [3]:
%%yaml inference_config

# Base inference settings shared by every run. Values written as '?' are
# placeholders: they must be overridden before the config is used, and the
# config-building cell calls check_missing to verify that none remain.

# log.root, log.log_image_stats, log.log_pixel_stats and log.show_examples are
# supplied by get_ese_inference_configs.
log:
    root: '?'
    log_interval: 5
    log_image_stats: '?' 
    log_pixel_stats: '?' 
    show_examples: '?' 

experiment:
    seed: 42

# data.preload is supplied by get_ese_inference_configs.
data:
    split: cal 
    slicing: dense_full 
    input_type: volume 
    preload: '?' 

dataloader:
    batch_size: 1 
    num_workers: 0
    pin_memory: True 

# NOTE(review): presumably read by the calibration metrics (confidence range,
# binning and neighborhood settings) - confirm against the ese metric code.
calibration:
    conf_interval:
        - 0.5
        - 1.0
    num_bins: 10
    neighborhood_width: 3
    square_diff: False 

<IPython.core.display.Javascript object>

In [4]:
%%yaml model_cfg

# Model-loading settings. Every '?' field below (pretrained_exp_root,
# checkpoint and the three ensemble options) is supplied per run by
# get_ese_inference_configs.
model:
    pretrained_exp_root : '?' 
    checkpoint: '?' 
    ensemble: '?' 
    ensemble_pre_softmax: '?'
    ensemble_combine_fn: '?' 
    # NOTE(review): presumably the metric used to pick among pretrained runs - confirm.
    pretrained_select_metric: "val-dice_score"

<IPython.core.display.Javascript object>

In [5]:
%%yaml metrics_cfg 

# Segmentation-quality metrics. Each entry is a display name mapped to the
# dotted path of the metric function (`_fn`) plus the keyword settings listed
# with it.
# NOTE(review): presumably `_fn` is resolved to a callable by the experiment
# code - confirm.
qual_metrics:
    - Dice:
        _fn: ese.experiment.metrics.dice_score
        from_logits: True
        batch_reduction: 'mean' 
        ignore_empty_labels: False # This is a WMH specific setting.
        ignore_index: 0 # Ignore background class when reporting.
        metric_type: quality
    - HD95:
        _fn: ese.experiment.metrics.hd95
        from_logits: True
        batch_reduction: 'mean' 
        ignore_empty_labels: False # This is a WMH specific setting.
        ignore_index: 0 # Ignore background class when reporting.
        metric_type: quality

# Calibration metrics. Only ECE is enabled; the commented-out entries below are
# disabled variants kept for reference.
cal_metrics:
    - ECE:
        _fn: ese.experiment.metrics.ece.ece_loss
    # - Edge_ECE:
    #     _fn: ese.experiment.metrics.ece.edge_ece_loss
    # - CW_ECE:
    #     _fn: ese.experiment.metrics.ece.cw_ece_loss
    # - ELM:
    #     _fn: ese.experiment.metrics.elm.elm_loss
    # - Foreground_ECE:
    #     _fn: ese.experiment.metrics.ece.ece_loss
    #     ignore_index: 0
    # - Foreground_Edge_ECE:
    #     _fn: ese.experiment.metrics.ece.edge_ece_loss
    #     ignore_index: 0
    # - Foreground_CW_ECE:
    #     _fn: ese.experiment.metrics.ece.cw_ece_loss
    #     ignore_index: 0
    # - Foreground_ELM:
    #     _fn: ese.experiment.metrics.elm.elm_loss
    #     ignore_index: 0

<IPython.core.display.Javascript object>

In [6]:
from ionpy.util import dict_product, Config
from ionpy.util.config import check_missing
from ese.scripts.utils import gather_exp_paths

def get_ese_inference_configs(
        calibrator,
        do_ensemble: bool,
        log_image_stats: bool,
        log_pixel_stats: bool,
        group_string: str,
        preload: bool = False,
        show_examples: bool = False,
        ):
    """Assemble the option grid for one group of WMH inference runs.

    Args:
        calibrator: Calibrator name appended to the calibration experiment
            directory, or None to use the uncalibrated training runs (in which
            case the label "Uncalibrated" is used in the log directory name).
        do_ensemble: True for ensemble inference, False to run each
            pretrained network on its own.
        log_image_stats: Value placed under 'log.log_image_stats'.
        log_pixel_stats: Value placed under 'log.log_pixel_stats'.
        group_string: Sub-directory of `root / "inference"` for this group.
        preload: Value placed under 'data.preload'.
        show_examples: Value placed under 'log.show_examples'.

    Returns:
        A tuple `(options, log_root)`: `options` maps dotted config keys to
        lists of candidate values, and `log_root` is the single 'log.root'
        entry as a string.

    Note:
        Reads the notebook-level `root` path and, for individual runs, calls
        `gather_exp_paths` on the pretrained experiment directory.
    """
    # Pick the pretrained experiment directory and the checkpoint to load.
    if calibrator is not None:
        # Calibrated networks: selected by lowest validation ECE.
        ensemble_root = f"/storage/vbutoi/scratch/ESE/calibration/01_07_24_WMH_Ensemble{calibrator}"
        checkpoint = "min-val-ece_loss"
    else:
        # Uncalibrated networks: selected by highest validation Dice.
        calibrator = "Uncalibrated"
        ensemble_root = "/storage/vbutoi/scratch/ESE/training/01_08_24_WMH_Ensemble"
        checkpoint = "max-val-dice_score"

    inference_dir = root / "inference" / group_string

    # Settings that differ between ensemble and individual-network inference.
    if do_ensemble:
        log_root = str(inference_dir / f"WMH_Ensemble_{calibrator}")
        pretrained_roots = [ensemble_root]
        ensemble_flag = [True]
        pre_softmax_choices = [True, False]
        combine_fn_choices = ['mean', 'max']
    else:
        log_root = str(inference_dir / f"WMH_Individual_{calibrator}")
        # Per the original note, this is a list of train exp paths.
        pretrained_roots = gather_exp_paths(ensemble_root)
        ensemble_flag = [False]
        pre_softmax_choices = [None]
        combine_fn_choices = [None]

    # Shared settings first, then the mode-specific ones (same key order as
    # building the shared dict and updating it with the mode-specific dict).
    config_options = {
        'model.checkpoint': [checkpoint],
        'data.preload': [preload],
        'log.show_examples': [show_examples],
        'log.log_image_stats': [log_image_stats],
        'log.log_pixel_stats': [log_pixel_stats],
        'log.root': [log_root],
        'model.pretrained_exp_root': pretrained_roots,
        'model.ensemble': ensemble_flag,
        'model.ensemble_pre_softmax': pre_softmax_choices,
        'model.ensemble_combine_fn': combine_fn_choices,
    }
    return config_options, log_root

In [7]:
# Calibrator options for the `calibrator` argument below:
#   None, TempScaling, VectorScaling, DirichletScaling, LTS

from datetime import datetime

# Stamp the run group with the current date, formatted as MM_DD_YY.
today_date = datetime.now()
formatted_date = today_date.strftime("%m_%d_%y")

# Option grid and log directory for this group of inference runs.
dataset_options, log_root = get_ese_inference_configs(
    calibrator=None,
    do_ensemble=False,
    log_image_stats=True,
    log_pixel_stats=True,
    group_string=f"{formatted_date}_EnsembleAnalysis",
    preload=False,
    show_examples=False,
)

In [8]:
# Build the configs.
##################################################
# Start from the inference YAML and merge in the model and metrics sections.
base_cfg = Config(inference_config).update([model_cfg, metrics_cfg])

# Create one config per option set yielded by dict_product.
# NOTE(review): presumably dict_product yields the Cartesian product of the
# option lists, and Config.update returns a new config rather than mutating
# base_cfg - confirm both in ionpy.
cfgs = []
for cfg_update in dict_product(dataset_options):
    new_cfg = base_cfg.update(cfg_update)
    check_missing(new_cfg) # Verify there are no ? in config.
    cfgs.append(new_cfg)

In [9]:
# Sanity check: how many inference configs were built.
len(cfgs)

4

## Running Jobs

In [10]:
from ese.experiment.experiment import run_ese_exp, submit_ese_exps
from ese.experiment.analysis.inference import get_cal_stats

In [11]:
###### Run individual jobs
# Launch one run for the first generated config, passing run name 'debug' and
# gpu '0' to run_ese_exp with get_cal_stats as the job function.
# NOTE(review): the saved output of this cell ends in a ValueError with no
# message - investigate before relying on any results from it.
run_ese_exp(
    config=cfgs[0], 
    job_func=get_cal_stats,
    run_name='debug',
    gpu='0',
) 

Set seed: 43
Set seed: 42
Local amounts:  tensor([    0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,     0., 65536.], dtype=torch.float64)
Local scores:  tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.5675e-05], dtype=torch.float64)

torch.Size([1, 2, 256, 256])
Global amounts:  tensor([65536.], dtype=torch.float64)
Global scores:  tensor([4.5675e-05], dtype=torch.float64)

Local amounts:  tensor([    0.,     0.,     0.,     0.,     0.,     0.,     0.,     0.,     0., 65536.], dtype=torch.float64)
Local scores:  tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.9020e-05], dtype=torch.float64)

torch.Size([1, 2, 256, 256])
Global amounts:  tensor([1.0000e+00, 6.5535e+04], dtype=torch.float64)
Global scores:  tensor([9.3331e-01, 2.4779e-05], dtype=torch.float64)

Local amounts:  tensor([    0.,     0.,     0.,     0.,     0.,     0.,     

ValueError: 

In [None]:
# ####### Run Batch Jobs
# submit_ese_exps(
#     exp_root=log_root,
#     job_func=get_cal_stats,
#     config_list=cfgs,
#     available_gpus=['0', '1', '2', '3']
#     # available_gpus=['4', '5', '6', '7']
# )