In [1]:
# torch imports
import torch
import torch._dynamo
torch._dynamo.config.suppress_errors = True
# Misc imports
import os 
import sys
import seaborn as sns
from pathlib import Path
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSegDev')
# Ionpy imports
from ionpy.util import Config
from ionpy.analysis import ResultsLoader

# Define some useful paths.
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
       '/storage'
))
# Set some defaults
rs = ResultsLoader()
sns.set_style("darkgrid")
torch.set_printoptions(linewidth=200)

# For using code without restarting.
%load_ext autoreload
%autoreload 2
# For using yaml configs.
%load_ext yamlmagic

In [2]:
%%yaml default_cfg 

# Default inference settings; the %%yaml magic binds the parsed mapping to `default_cfg`.
# NOTE(review): the '?' values look like placeholders to be overridden later.
# split / batch_size / num_workers are set in the experiment config cell; exp_root and
# log.root are presumably filled in by the config-building helper — TODO confirm.

experiment:
    exp_root: '?'
    inference_seed: 40

inference_data:
    split: '?'

log:
    root: '?'
    save_preds: False 
    log_interval: 10 
    log_pixel_stats: False 
    # NOTE(review): 'gether' looks like a misspelling of 'gather'. This key is read at
    # runtime, so rename it only together with whatever code consumes it.
    gether_inference_stats: True
    compute_global_metrics: False 

dataloader:
    batch_size: '?' 
    num_workers: '?' 
    pin_memory: True 

<IPython.core.display.Javascript object>

In [3]:
%%yaml calibration_cfg 

# Calibration settings; the %%yaml magic binds the parsed mapping to `calibration_cfg`,
# which is later merged into the base config together with the model config.
# NOTE(review): num_prob_bins / neighborhood_width are presumably the number of
# confidence bins and the pixel-neighborhood size used by the calibration metrics —
# confirm against the code that reads them.

local_calibration:
    num_prob_bins: 15
    neighborhood_width: 3

# Same bin count and neighborhood width as the local section, plus a class count.
global_calibration:
    num_classes: 1 
    num_prob_bins: 15
    neighborhood_width: 3

<IPython.core.display.Javascript object>

In [4]:
%%yaml model_cfg 

# Model-loading settings; the %%yaml magic binds the parsed mapping to `model_cfg`.

# For standard datasets
#####################################
model:
    pred_label: 0     
    _type: "standard"
    # NOTE(review): YAML loads the bare scalar `None` as the string "None", not as null
    # (null is spelled `null` or `~`). Confirm the consumer expects the string before changing it.
    pretrained_exp_root : None
    # Presumably the name of the saved checkpoint to restore — TODO confirm.
    checkpoint: 'min-val-abs_area_estimation_error'

<IPython.core.display.Javascript object>

## Gather Inference Options.

In [5]:
%%yaml experiment_cfg 

# Experiment definition; the %%yaml magic binds the parsed mapping to `experiment_cfg`.
# It is handed to the config-building helper and to the job-submission call.

## NAMING FIELDS

group: "UVS_InContext_CrossEval"

# Exactly one subgroup should be active; the alternatives are kept commented out.
# subgroup: "Base"
# subgroup: "Sweep_Threshold"
# subgroup: "Optimal_Dice_Threshold"
subgroup: "Optimal_RAVE_Threshold"

############################################################################################################

## EXPERIMENTAL VARIABLES

# Trained experiment directory (or directories) to run inference with.
base_model:
    - "/storage/vbutoi/scratch/ESE/training/02_16_23_Universeg-HO_Adrian/20230214_112221-1C7K-2cb972b55f72d56b3842c2d2d8bdd3c6"

# Change this for debugging
dataloader:
    batch_size: 1
    num_workers: 1

experiment:
    # NOTE(review): 'incontex' looks like a misspelling of 'incontext'. This key is read
    # at runtime, so rename it only together with whatever code consumes it.
    crosseval_incontex: False 
    num_supports: 5 

inference_data:
    _class: "ese.datasets.Segment2D"
    label: 0
    split: "val"
    support_split: "train"
    # Active tasks; with the single base model above this notebook run produced 4 configs,
    # presumably one per task listed here — TODO confirm.
    task: 
        # - "ACDC/Challenge2017/MRI/2"
        # - "PanDental/v1/XRay/0"
        # - "PanDental/v2/XRay/0"
        # - "SCD/LAS/MRI/2"
        # - "SCD/VIS_human/MRI/2"
        # - "SCD/LAF_Post/MRI/2"
        # - "SCD/VIS_pig/MRI/2"
        # - "SCD/LAF_Pre/MRI/2"
        - "SpineWeb/Dataset7/MR/0"
        - "STARE/retrieved_2021_12_06/Retinal/0"
        - "WBC/CV/EM/0"
        - "WBC/JTSC/EM/0"

############################################################################################################

## Special Inference Protocols

# sweep:
#    param: "threshold" 

# NOTE(review): presumably selects, per `id_key` value, the `sweep_key` value from an
# earlier threshold sweep that optimizes `metric` (here: minimum hard_RAVE on the train
# split) — confirm against the config-building helper.
load_optimal_args: 
  id_key: "inference_data.task"
  sweep_key: "threshold"
  metric: "hard_RAVE"
  split: "train"
  mode: "min"

<IPython.core.display.Javascript object>

In [6]:
# Local imports
from ese.analysis.analysis_utils.submit_utils import get_ese_inference_configs

# Layer the calibration and model settings on top of the defaults to form the base config.
cfg_overrides = [calibration_cfg, model_cfg]
base_cfg = Config(default_cfg).update(cfg_overrides)

# Flags reused by the job-submission cell.
# The date IS added for this run (add_date=True).
add_date = True
use_best_models = False

# Expand the experiment yaml into the per-run inference configs.
updated_base_cfg, inf_cfgs = get_ese_inference_configs(
    exp_cfg=experiment_cfg,
    base_cfg=base_cfg,
    add_date=add_date,
    use_best_models=use_best_models,
)

  warn(



Loading threshold sweep dataframe...
Finished loading inference stats.
Log amounts: log_root                                                                                log_set                                              
/storage/vbutoi/scratch/ESE/inference/11_05_24_UVS_InContext_CrossEval/Sweep_Threshold  20241105_011331-M2MC-863dcfaeb997edbae01b69e73ba5e3c4     5000
                                                                                        20241105_011336-2U6V-7460588958f18a8c1a6db15746e5ac29     5000
                                                                                        20241105_011340-4TIO-b13bc0d9b9bf0adaefd4154b9b1f527a     5000
                                                                                        20241105_011344-ER9T-4d09ca6cdcacf9a9a956fdaaba033364     5000
                                                                                        20241105_011436-NUKZ-d99e3038bce9ecbd5a42b555af2f3983    13600
                   

In [7]:
# Sanity check: how many inference configs were generated (one job is submitted per config).
len(inf_cfgs)

4

## Running Jobs

In [8]:
from ese.analysis.run_inference import get_cal_stats

In [9]:
# from ese.experiment import run_ese_exp

# ###### Run a single job locally for debugging (uncomment this cell to use it instead of the batch submission)
# run_ese_exp(
#     config=inf_cfgs[0], 
#     job_func=get_cal_stats,
#     run_name='debug',
#     show_examples=False,
#     gpu='0',
# )

In [10]:
from ese.experiment import submit_ese_exps

# GPU ids the submitted jobs may be placed on.
gpu_ids = ['0', '1', '2', '3']

#### Run Batch Jobs
# Submits every config in `inf_cfgs`, with `get_cal_stats` as the job function.
submit_ese_exps(
    group="inference",
    base_cfg=updated_base_cfg,
    exp_cfg=experiment_cfg,
    config_list=inf_cfgs,
    add_date=add_date,
    job_func=get_cal_stats,
    available_gpus=gpu_ids,
)

Submitting job 1/4:
--> Launched job-id: 4129911 on gpu: 0.
Submitting job 2/4:
--> Launched job-id: 4130021 on gpu: 1.
Submitting job 3/4:
--> Launched job-id: 4130172 on gpu: 2.
Submitting job 4/4:
--> Launched job-id: 4130464 on gpu: 3.
