In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')
import torch
torch.set_printoptions(linewidth=200)
import seaborn as sns
sns.set_style("darkgrid")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 
# For using yaml configs.
%load_ext yamlmagic

In [None]:
# Load the configs of the original training runs via the results loader.
# NOTE(review): `dfc` is not referenced by any later cell shown in this notebook.
training_exp_name = "training/01_03_24_WMH_Ensemble"
training_exp_dir = root / training_exp_name
dfc = rs.load_configs(training_exp_dir, properties=False)

In [None]:
%%yaml inference_config

# Base inference config. Values of '?' are placeholders that the option sweep
# later in the notebook fills in; check_missing is then used to verify none remain.
log:
    # Output directory; filled from 'log.root' in the sweep.
    root: '?'
    log_interval: 5
    log_image_stats: True 
    log_pixel_stats: False 
    # Presumably the background label, matching the Dice metric's ignore_index.
    ignore_index: 0 
    show_examples: False 

experiment:
    seed: 42

data:
    # Inference runs on the 'cal' split.
    split: cal 
    slicing: dense_full 
    input_type: volume 
    # Filled from 'data.preload' in the sweep.
    preload: '?' 

dataloader:
    batch_size: 1 
    num_workers: 0
    pin_memory: True 

# NOTE(review): presumably the confidence range / binning used by the
# calibration statistics — confirm against get_cal_stats.
calibration:
    conf_interval_start: 0.5
    conf_interval_end: 1.0
    num_bins: 10
    neighborhood_width: 3
    square_diff: False 

In [None]:
%%yaml model_cfg

# Model-loading settings. Every '?' is a placeholder that the option sweep
# fills in (model.pretrained_exp_root, model.checkpoint, model.ensemble,
# model.ensemble_pre_softmax, model.ensemble_combine_fn).
model:
    pretrained_exp_root : '?' 
    checkpoint: '?' 
    ensemble: '?' 
    ensemble_pre_softmax: '?'
    ensemble_combine_fn: '?' 
    # NOTE(review): presumably the metric used to rank/select pretrained runs —
    # confirm against the model-loading code.
    pretrained_select_metric: "val-dice_score"

In [None]:
%%yaml metrics_cfg 

# Metrics evaluated during inference. Each entry maps a display name to the
# import path of the metric function (`_fn`) plus its settings.
qual_metrics:
    # Calibration metrics currently disabled:
    # - Edge_ECE:
    #     _fn: ese.experiment.metrics.ece.edge_ece_loss
    #     metric_type: calibration
    # - ECW_ECE:
    #     _fn: ese.experiment.metrics.ece.ecw_ece_loss
    #     metric_type: calibration
    - Dice:
        _fn: ionpy.metrics.segmentation.dice_score
        # Predictions are passed as logits.
        from_logits: True
        batch_reduction: 'mean' 
        ignore_empty_labels: True 
        ignore_index: 0 # Ignore background class when reporting.
        metric_type: quality


# NOTE(review): the whole `cal_metrics` section is commented out, so this config
# defines only quality metrics — confirm get_cal_stats does not require it.
# cal_metrics:
#     - ECE:
#         _fn: ese.experiment.metrics.ece.ece_loss
#     - CW_ECE:
#         _fn: ese.experiment.metrics.ece.cw_ece_loss
#     - ELM:
#         _fn: ese.experiment.metrics.elm.elm_loss
#     - CW_ELM:
#         _fn: ese.experiment.metrics.elm.cw_elm_loss

In [None]:
from ionpy.util import dict_product, Config
from ionpy.util.config import check_missing
from ese.scripts.utils import gather_exp_paths

# Get the pretrained models to evaluate.
##################################################
# Named (root dir, checkpoint) presets. Select one with PRETRAINED_PRESET
# instead of commenting/uncommenting lines.
PRETRAINED_PRESETS = {
    # Trained ensemble under training/, using its "max-val-dice_score" checkpoint.
    "ensemble": (
        "/storage/vbutoi/scratch/ESE/training/01_08_24_WMH_Ensemble",
        "max-val-dice_score",
    ),
    # LTS ensemble under calibration/, using its "min-val-ece_loss" checkpoint.
    "ensemble_lts": (
        "/storage/vbutoi/scratch/ESE/calibration/01_07_24_WMH_EnsembleLTS",
        "min-val-ece_loss",
    ),
}
PRETRAINED_PRESET = "ensemble_lts"
ensemble_root, checkpoint = PRETRAINED_PRESETS[PRETRAINED_PRESET]
# Experiment directories found under `ensemble_root` (a list of train exp paths).
pretrained_exp_paths = gather_exp_paths(ensemble_root)

# Make presets for the different running configurations.
##################################################
# True  -> evaluate the combined ensemble (model.ensemble=True).
# False -> evaluate each network in `pretrained_exp_paths` on its own.
RUN_ENSEMBLE = True

# Inference on individual networks.
individual_network_args = {
    'model.pretrained_exp_root': pretrained_exp_paths, # Note this is a list of train exp paths.
    'model.ensemble': [False],
    'model.ensemble_pre_softmax': [None],
    'model.ensemble_combine_fn': [None],
}

# Inference on the whole ensemble, sweeping how member outputs are combined
# (each list is presumably expanded as a Cartesian product by dict_product).
ensemble_network_args = {
    'model.pretrained_exp_root': [ensemble_root],
    'model.ensemble': [True],
    'model.ensemble_pre_softmax': [True, False],
    'model.ensemble_combine_fn': ['mean', 'max'],
}

# Get the inference options.
##################################################
log_root = str(root / "inference/01_08_24_WMH_EnsembleLTS")
dataset_options = {
    'log.root': [log_root],
    'model.checkpoint': [checkpoint],
    'data.preload': [False],
}
dataset_options.update(
    ensemble_network_args if RUN_ENSEMBLE else individual_network_args
)

In [None]:
# Build one config per combination of sweep options.
##################################################
base_cfg = Config(inference_config).update([model_cfg, metrics_cfg])

cfgs = []
for option_set in dict_product(dataset_options):
    candidate_cfg = base_cfg.update(option_set)
    # Verify no '?' placeholder survived the merge before keeping the config.
    check_missing(candidate_cfg)
    cfgs.append(candidate_cfg)

In [None]:
# Number of inference configs produced by the sweep.
num_cfgs = len(cfgs)
num_cfgs

## Running Jobs

In [None]:
from ese.experiment.experiment import run_ese_exp, submit_ese_exps
from ese.experiment.analysis.inference import get_cal_stats

In [None]:
###### Run a single job: the first config, named 'debug', on GPU '0'.
debug_cfg = cfgs[0]
run_ese_exp(config=debug_cfg, job_func=get_cal_stats, run_name='debug', gpu='0')

In [None]:
####### Run Batch Jobs
# Set to True to submit every config in `cfgs` (logging under `log_root`)
# across the listed GPUs. Off by default so Run-All only does the debug run.
SUBMIT_BATCH_JOBS = False
if SUBMIT_BATCH_JOBS:
    submit_ese_exps(
        exp_root=log_root,
        job_func=get_cal_stats,
        config_list=cfgs,
        available_gpus=['0', '1', '2', '3'],
    )