In [None]:
# Random Imports
import os
import sys
import yaml
from pathlib import Path
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

# Regular schema dictates that we put DATAPATH
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))
os.environ['WANDB_NOTEBOOK_NAME'] = 'calibrate.ipynb'

from ionpy.util import Config

# Setup direcrtories
root = Path("/storage/vbutoi/scratch/ESE")
scratch_root = Path("/storage/vbutoi/scratch/ESE")
code_root = Path("/storage/vbutoi/projects/ESE")

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [None]:
%%yaml default_cfg

# Base config shared by every calibration run.
# Values written as '?' are placeholders; presumably they must be filled in
# (e.g. from experiment_cfg) before a run can start — confirm.
# NOTE(review): experiment_cfg sets dataloader.batch_size, train.pretrained_dir and
# loss_func._class, but NOT data.iters_per_epoch or train.epochs. Presumably
# get_ese_calibration_configs fills those in — confirm.

data:
  iters_per_epoch: '?'
  train_splits: cal
  val_splits: val

dataloader:
  batch_size: '?' # Often, we train with a small num of images total.
  num_workers: 1
  pin_memory: True

optim: # Unclear if we should tune this or not.
  _class: torch.optim.Adam
  weight_decay: 0.0
  lr: 1.0e-4

train:
  epochs: '?' # Placeholder. Total iterations is presumably epochs * data.iters_per_epoch (e.g. 10 * 100 = 1000).
  eval_freq: 10
  pretrained_dir: '?'
  # NOTE(review): presumably selects which checkpoint of the pretrained model in
  # pretrained_dir to load (paired with pretrained_select_metric) — confirm.
  checkpoint: 'max-val-dice_score'
  pretrained_select_metric: 'val-dice_score'

loss_func:
  _class: '?'
  from_logits: True
  batch_reduction: 'mean'

In [None]:
%%yaml callbacks_cfg

# Callback class paths, listed under the `step` and `epoch` hooks (order is preserved).
# Entries that need arguments are written as a one-key mapping: {class_path: {kwargs}}.
callbacks:
  step:
  - ese.experiment.callbacks.ShowPredictions
  epoch:
  - ese.experiment.callbacks.WandbLogger
  - ionpy.callbacks.ETA
  - ionpy.callbacks.JobProgress
  - ionpy.callbacks.TerminateOnNaN
  - ionpy.callbacks.PrintLogged
  - ionpy.callbacks.ModelCheckpoint: {monitor: ece_loss, phase: val}

In [None]:
%%yaml calibrator_cfgs

# Calibration-method configs, keyed by the names listed under `model` in experiment_cfg
# (presumably looked up by get_ese_calibration_configs — confirm).

TempScaling:
  _class: ese.experiment.models.calibrators.Temperature_Scaling

ImageBasedTS:
  _class: ese.experiment.models.calibrators.ImageBasedTS
  img_channels: 1
  num_classes: 1

LocalTS:
  _class: ese.experiment.models.calibrators.LocalTS
  img_channels: 1
  num_classes: 1
  filters: [4, 4, 4]

In [None]:
%%yaml experiment_cfg

# Experiment-specific settings, passed as exp_cfg alongside base_cfg
# (presumably merged over it by get_ese_calibration_configs — confirm).
name: "ACDC_CalibratorSet"

train:
  pretrained_dir: "/storage/vbutoi/scratch/ESE/training/07_11_24_ACDC_PixelFocalLoss"

# Calibrators to run; names refer to entries in calibrator_cfgs. Uncomment to add more.
model:
  # - TempScaling
  # - ImageBasedTS
  - LocalTS

data:
  _class: "ese.experiment.datasets.ACDC"
  label_threshold: null
  version: 0.1
  num_examples: [5]

dataloader:
  batch_size: 5

loss_func:
  _class: ese.experiment.losses.SoftDiceLoss
  # Alternatives:
  # _class: ese.experiment.losses.PixelCELoss
  # _class: ese.experiment.losses.PixelFocalLoss
  # alpha: 0.25
  # gamma: 2.0

In [None]:
from ese.experiment.analysis.analysis_utils.submit_utils import get_ese_calibration_configs

# Start from the default config, then fold the callback block into it.
base_cfg = Config(default_cfg)
base_cfg = base_cfg.update([callbacks_cfg])

# Expand the experiment + calibrator settings into the sequence of per-run configs.
cal_cfgs = get_ese_calibration_configs(
    base_cfg=base_cfg,
    exp_cfg=experiment_cfg,
    calibration_model_cfgs=calibrator_cfgs,
)

In [None]:
# Number of run configs that were generated.
num_cal_runs = len(cal_cfgs)
num_cal_runs

## Running Jobs

In [None]:
from ese.experiment.experiment import run_ese_exp, submit_ese_exps, PostHocExperiment

In [None]:
####### Run individual jobs
# Smoke-test the first generated config locally on GPU 0, without W&B tracking.
debug_cfg = cal_cfgs[0]
run_ese_exp(
    config=debug_cfg,
    experiment_class=PostHocExperiment,
    run_name='debug',
    gpu='0',
    track_wandb=False,
    show_examples=True,
)

In [None]:
# ##### Run Batch Jobs
# Disabled by default. Uncomment to submit every generated config (`cal_cfgs`, built above)
# across the listed GPUs. It previously referenced an undefined name `cfgs`.
# submit_ese_exps(
#     config_list=cal_cfgs,
#     experiment_class=PostHocExperiment,
#     track_wandb=True,
#     available_gpus=['0', '1', '2', '3']
# )