In [1]:
# General-purpose imports (stdlib + PyYAML)
import os
import sys
import yaml
from pathlib import Path
from pprint import pprint
# Make the project and library checkouts importable. Entries are appended only
# when missing, so re-running this cell does not keep growing sys.path.
def add_to_sys_path(*paths):
    """Append each of ``paths`` (converted to ``str``) to ``sys.path`` unless already present.

    Appending a path that is already on ``sys.path`` never changes import
    resolution (the earlier entry is searched first), so skipping duplicates
    only keeps the list from growing each time the cell is re-executed.
    """
    for entry in map(str, paths):
        if entry not in sys.path:
            sys.path.append(entry)


STORAGE_ROOT = Path('/storage/vbutoi')
add_to_sys_path(
    STORAGE_ROOT / 'projects',
    STORAGE_ROOT / 'libraries',
    STORAGE_ROOT / 'libraries' / 'voxynth',
    STORAGE_ROOT / 'projects' / 'ESE',
    STORAGE_ROOT / 'projects' / 'UniverSegDev',
)

# ':'-separated list of dataset roots (a single root for now). Presumably read by
# the project's dataset classes to locate data — TODO confirm the consumer.
os.environ['DATAPATH'] = ':'.join((
    str(STORAGE_ROOT / 'datasets'),
))
# Notebook name reported to Weights & Biases for runs launched from here.
os.environ['WANDB_NOTEBOOK_NAME'] = 'calibrate.ipynb'

# IonPy imports
from ionpy.util import Config

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [2]:
%%yaml default_cfg

# Base training config for the post-hoc calibration runs. A '?' marks a value
# that must be supplied later, either by experiment_cfg (e.g. val_first, lr,
# loss_func._class) or by get_ese_calibration_configs (in cal_cfgs[0], epochs
# became 3000 and eval_freq 20 although experiment_cfg sets neither).

experiment:
  val_first: '?' 
  torch_compile: '?' 
  torch_mixed_precision: False 

data:
  # Resolved to None in cal_cfgs[0].
  iters_per_epoch: '?' 
  # Calibrator is trained on the 'cal' split and validated on 'val'.
  train_splits: cal
  val_splits: val 

dataloader:
  batch_size: '?'
  num_workers: '?' 
  pin_memory: True 

optim: # TODO: decide whether lr should be tuned at all; experiment_cfg currently sweeps it.
  _class: torch.optim.Adam
  lr: '?'

train:
  epochs: '?'
  eval_freq: '?'
  # Run directory of the pretrained base model (set per experiment in experiment_cfg).
  base_pretrained_dir: '?'
  use_pretrained_norm_augs: True 
  # Presumably: which checkpoint of the base model to load, and the metric used
  # to select it — confirm against the loader.
  base_checkpoint: 'max-val-dice_score'
  base_pt_select_metric: 'val-dice_score'

loss_func: 
  _class: '?'

<IPython.core.display.Javascript object>

In [3]:
%%yaml log_cfg

# Logging and metric settings, merged into the base config.
log:
  # Checkpoint cadence (presumably in epochs — TODO confirm).
  checkpoint_freq: 50
  # Filled in by get_ese_calibration_configs; in cal_cfgs[0] it is a dated
  # directory under /storage/vbutoi/scratch/ESE/calibration/.
  root: '?'
  # Tracked metrics: each entry names the function (_fn); the remaining keys are
  # presumably passed to it as keyword arguments.
  metrics:
    dice_score:
      _fn: ionpy.metrics.dice_score
      batch_reduction: 'mean' 
      ignore_empty_labels: False 
      from_logits: True
    abs_area_estimation_error:
      _fn: ese.losses.area_estimation_error
      from_logits: True
      abs_diff: True

<IPython.core.display.Javascript object>

In [4]:
%%yaml callbacks_cfg

# Callbacks grouped by trigger; 'step' entries presumably fire every training
# step and 'epoch' entries at the end of each epoch.
callbacks:
  step:
    - ese.callbacks.ShowPredictions
  epoch:
    - ese.callbacks.WandbLogger
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    # Checkpointing keyed on val-phase metrics.
    # NOTE(review): dice_score is higher-is-better while abs_area_estimation_error
    # is lower-is-better; confirm ModelCheckpoint uses the right direction per metric.
    - ionpy.callbacks.ModelCheckpoint:
        monitor: 
          - dice_score 
          - abs_area_estimation_error
        phase: val

<IPython.core.display.Javascript object>

In [5]:
%%yaml calibrator_defaults_cfg 

# Default hyperparameters for each calibrator, keyed by name. experiment_cfg's
# model.class_name refers to one of these keys (presumably how
# get_ese_calibration_configs picks one); '?' fields must be set by the experiment.

# Temperature scaling; only the class is configured here.
TS:
  _class: ese.models.calibrators.Temperature_Scaling

# Local temperature scaling; dims must be set per experiment.
LTS:
  _class: ese.models.calibrators.LocalTS
  img_channels: 1
  num_classes: 1
  dims: '?'
  filters: [8, 8, 8]

# Larger 3D LocalTS variant (five 64-filter levels) that also takes the image.
3D_LTS_Huge:
  _class: ese.models.calibrators.LocalTS
  img_channels: 1
  num_classes: 1
  use_image: True
  abs_output: '?'
  dims: 3
  convs_per_block: 2
  filters: [64, 64, 64, 64, 64]

# 3D SCTS calibrator; use_image / use_norm are supplied by experiment_cfg's model section.
SCTS:
  _class: ese.models.resnet.SCTS
  img_channels: 1
  num_classes: 1
  use_image: '?' 
  use_norm: '?' 
  filters: [64, 64, 64]
  blocks_per_layer: 2
  dims: 3

<IPython.core.display.Javascript object>

In [6]:
%%yaml aug_cfg

# Spatial (affine + warp) augmentation settings.
# As far as this notebook shows, aug_cfg is not merged into base_cfg, so these
# augmentations are inactive unless aug_cfg is added to the Config(...).update([...])
# call in the config-building cell.
augmentations:
    spatial:
        max_translation: 5.0
        max_rotation: 5.0
        max_scaling: 1.1
        warp_integrations: 5
        warp_smoothing_range: [10, 20]
        warp_magnitude_range: [1, 3]
        affine_probability: 0.5
        warp_probability: 0.5

<IPython.core.display.Javascript object>

In [7]:
%%yaml experiment_cfg 

# Experiment-specific overrides layered onto base_cfg by get_ese_calibration_configs.
# List values appear to act as sweep axes: optim.lr lists two values,
# len(cal_cfgs) == 2, and cal_cfgs[0] uses lr 1e-05.

# Experiment name; it also appears in the (dated) log root directory.
name: "ISLES_3D_SCTS_PredTemp_Small_wNorm"

experiment:
    val_first: False 
    torch_compile: True 

train: 
    # Run directory of the pretrained base model whose predictions are calibrated.
    base_pretrained_dir: "/storage/vbutoi/scratch/ESE/training/09_25_24_ISLES_3D_Dice_HeavyAug/20240925_234556-HGRQ-56cd3bf7df7a1b7d2a453b296d64407d"

data:
    _class: ese.datasets.ISLES
    # NOTE(review): presumably makes the dataset yield target temperatures
    # (paired with the MSE loss below) — confirm.
    target: 'temp'

model:
    # Presumably selects the SCTS entry of calibrator_defaults_cfg; the keys below
    # fill its '?' fields.
    class_name: SCTS
    # NOTE(review): YAML has no tuple syntax, so this loads as the *string*
    # "(0.0, 3.0)"; presumably parsed downstream — confirm. Do not switch to a
    # list like [0.0, 3.0] without checking: lists in this config appear to be
    # expanded as sweep axes (see optim.lr).
    temp_range: (0.0, 3.0)
    use_image: True
    use_norm: True 
  
optim:
    # Swept: one run per learning rate. Keep the '.0': PyYAML (presumably what
    # yamlmagic uses) only resolves exponent floats that contain a decimal
    # point, so 1e-5 would load as a string.
    lr: 
        - 1.0e-5
        - 1.0e-4

dataloader:
    batch_size: 4
    num_workers: 3

loss_func:
    _class: torch.nn.MSELoss # Regression loss for optimizing the temperatures directly.

<IPython.core.display.Javascript object>

In [8]:
from ese.analysis.analysis_utils.submit_utils import get_ese_calibration_configs

# Shared with the batch-submission cell below, so config generation and job
# submission agree on whether run names get a date prefix.
add_date = True

# Base config = defaults overlaid with the logging and callback settings.
# To enable the spatial augmentations, append aug_cfg to this list.
base_overlays = [log_cfg, callbacks_cfg]
base_cfg = Config(default_cfg).update(base_overlays)

# Expand the experiment options + calibrator defaults into one config per run.
updated_base_cfg, cal_cfgs = get_ese_calibration_configs(
    exp_cfg=experiment_cfg,
    base_cfg=base_cfg,
    calibration_model_cfgs=calibrator_defaults_cfg,
    add_date=add_date,
)

In [9]:
# Number of generated run configs (presumably one per swept optim.lr value).
len(cal_cfgs)

2

In [10]:
# Inspect the first generated run config (full nested Config; the output is long).
cal_cfgs[0]

Config({'experiment': {'val_first': False, 'torch_compile': True, 'torch_mixed_precision': False}, 'data': {'iters_per_epoch': None, 'train_splits': 'cal', 'val_splits': 'val', '_class': 'ese.datasets.ISLES', 'target': 'temp'}, 'dataloader': {'batch_size': 4, 'num_workers': 3, 'pin_memory': True}, 'optim': {'_class': 'torch.optim.Adam', 'lr': 1e-05}, 'train': {'epochs': 3000, 'eval_freq': 20, 'base_pretrained_dir': '/storage/vbutoi/scratch/ESE/training/09_25_24_ISLES_3D_Dice_HeavyAug/20240925_234556-HGRQ-56cd3bf7df7a1b7d2a453b296d64407d', 'use_pretrained_norm_augs': True, 'base_checkpoint': 'max-val-dice_score', 'base_pt_select_metric': 'val-dice_score'}, 'loss_func': {'_class': 'torch.nn.MSELoss'}, 'log': {'checkpoint_freq': 50, 'root': '/storage/vbutoi/scratch/ESE/calibration/10_06_24_ISLES_3D_SCTS_PredTemp_Small_wNorm', 'metrics': {'dice_score': {'_fn': 'ionpy.metrics.dice_score', 'batch_reduction': 'mean', 'ignore_empty_labels': False, 'from_logits': True}, 'abs_area_estimation_err

## Running Jobs

In [11]:
from ese.experiment import run_ese_exp, submit_ese_exps, PostHocExperiment




In [12]:
# Debug path: run a single config in-process instead of submitting batch jobs.
# Set RUN_DEBUG_JOB = True to use it. Kept as live code behind a flag (rather
# than a commented-out block) so it stays visible to syntax checks and linters.
RUN_DEBUG_JOB = False
if RUN_DEBUG_JOB:
    run_ese_exp(
        config=cal_cfgs[0],
        experiment_class=PostHocExperiment,
        run_name='debug',
        show_examples=False,
        track_wandb=False,
        gpu='5',
    )

In [13]:
# Run batch jobs: submit every generated calibration config. Per the printed
# output, one job id is reported per config, each placed on one of these GPUs.
batch_gpus = ['6', '7']
submit_ese_exps(
    experiment_class=PostHocExperiment,
    group="calibration",
    config_list=cal_cfgs,
    base_cfg=updated_base_cfg,
    exp_cfg=experiment_cfg,
    add_date=add_date,
    track_wandb=True,
    available_gpus=batch_gpus,
)

Submitted job id: 2571865 on gpu: 6.
Submitted job id: 2572075 on gpu: 7.
