In [1]:
# Define useful paths
SCRATCH_ROOT = "/storage/vbutoi/scratch/ESE"
CONFIG_ROOT = "/storage/vbutoi/projects/ESE/configs"
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/projects/ESE')

# IonPy imports
from ionpy.util import Config

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [2]:
%%yaml default_cfg 

# Defaults shared by both config combos built later in this notebook:
# `base_cfg` and `finetune_cfg` each start from Config(default_cfg).
experiment:
  seed: 40
  seed_range: 1
  val_first: False 
  torch_compile: False 
  torch_mixed_precision: False
  # NOTE(review): absolute, user-specific paths; the first and third entries repeat
  # the sys.path additions made in the setup cell.
  sys_paths:
    - "/storage/vbutoi/projects"
    - "/storage/vbutoi/libraries"
    - "/storage/vbutoi/projects/ESE"
  data_paths:
    - "/storage"
    - "/storage/vbutoi/datasets"

# experiment_cfg (further down) overrides batch_size and num_workers to 1.
dataloader:
  batch_size: 16 
  num_workers: 4 
  pin_memory: True 

optim: 
  _class: torch.optim.Adam
  lr: 1.0e-4 
  
train:
  epochs: 500 
  eval_freq: 5 

log:
  checkpoint_freq: 5 

<IPython.core.display.Javascript object>

# Define the data.

In [3]:
%%yaml base_data_cfg 

# Used for additional data experiment.
# Data section merged into base_cfg: the dataset class plus per-phase kwargs
# selecting the 'train' / 'val' split.
data:
  _class: 'ese.datasets.ISLES'
  train_kwargs:
    split: 'train'
  val_kwargs: 
    split: 'val'

<IPython.core.display.Javascript object>

In [4]:
%%yaml finetune_data_cfg 

# Used for additional data experiment.
# Data section merged into finetune_cfg. Unlike base_data_cfg it names no dataset
# `_class`; it sets `use_pt_data_cfg` instead.
# NOTE(review): presumably this makes the experiment reuse the pretrained model's
# data config — confirm in CalibrationExperiment.
data:
  use_pt_data_cfg: True
  train_kwargs:
    split: 'train'
  val_kwargs: 
    split: 'val'

<IPython.core.display.Javascript object>

# Define the Loss Function config.

In [5]:
%%yaml base_loss_cfg

# Loss section merged into base_cfg (SoftDiceLoss). `batch_reduction`,
# `ignore_empty_labels` and `from_logits` use the same values as the dice_score
# metric in base_callbacks_cfg.
loss_func: 
  _class: ese.losses.SoftDiceLoss
  from_logits: True
  batch_reduction: 'mean' 
  ignore_empty_labels: False 

<IPython.core.display.Javascript object>

In [6]:
%%yaml finetune_loss_cfg

# Loss section merged into finetune_cfg: PixelCELoss instead of the SoftDiceLoss
# configured in base_loss_cfg.
loss_func: 
  _class: ese.losses.PixelCELoss
  from_logits: True
  batch_reduction: 'mean'

<IPython.core.display.Javascript object>

# Define the Model config.

## For base models.

In [7]:
%%yaml base_model_cfg  

# Model section merged into base_cfg: a UNet configured with seven filter entries of 64.
# NOTE(review): `dims: 3` presumably selects 3D convolutions — confirm against
# ese.models.unet.UNet.
model:
  _class: ese.models.unet.UNet
  filters: [64, 64, 64, 64, 64, 64, 64] 
  convs_per_block: 3 # Good default for UNets.
  dims: 3

<IPython.core.display.Javascript object>

## For downstream models (calibrators).

In [8]:
%%yaml finetune_model_cfg  

# Model section merged into finetune_cfg. `_class` and `base_model_dir` are '?'
# placeholders here; experiment_cfg defines both keys under `model`.
model:
  _class: '?'
  base_model_dir: '?'
  dims: 3

<IPython.core.display.Javascript object>

# Define the Callbacks config.

In [9]:
%%yaml base_callbacks_cfg

# Logging + callback section merged into base_cfg.
# `log.root` is a '?' placeholder in this cell.
log:
  root: '?'
  metrics:
    dice_score:
      _fn: ionpy.metrics.dice_score
      batch_reduction: 'mean' 
      ignore_empty_labels: False 
      from_logits: True

callbacks:
  step:
    - ionpy.callbacks.ShowPredictions:
        vis_type: 'segmentation'
  epoch:
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    - ionpy.callbacks.WandbLogger:
        entity: 'vbutoi'
        project: 'SemanticCalibration'
    # Checkpointing monitors the single `dice_score` metric on the val phase.
    - ionpy.callbacks.ModelCheckpoint:
        monitor: dice_score
        phase: val

<IPython.core.display.Javascript object>

In [10]:
%%yaml finetune_callbacks_cfg

# Logging + callback section merged into finetune_cfg. Compared with
# base_callbacks_cfg it adds the `ece_loss` metric and monitors two metrics
# for checkpointing instead of one.
log:
  root: '?'
  metrics:
    ece_loss:
      _fn: ese.metrics.image_ece_loss
      num_prob_bins: 15
      from_logits: True
    dice_score:
      _fn: ionpy.metrics.dice_score
      batch_reduction: 'mean' 
      ignore_empty_labels: False 
      from_logits: True

callbacks:
  step:
    - ionpy.callbacks.ShowPredictions:
        vis_type: 'segmentation'
  epoch:
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    - ionpy.callbacks.WandbLogger:
        entity: 'vbutoi'
        project: 'SemanticCalibration'
    # NOTE(review): `monitor` is a list here but a scalar in base_callbacks_cfg.
    # dice_score is higher-is-better while a calibration error is normally
    # lower-is-better — confirm ModelCheckpoint resolves the direction per metric.
    - ionpy.callbacks.ModelCheckpoint:
        monitor: 
          - dice_score
          - ece_loss
        phase: val

<IPython.core.display.Javascript object>

# Define the config combos.

In [11]:
# Config for training the standard segmentation models: start from the shared
# defaults and apply the base data / loss / model / callbacks sections.
base_cfg = Config(default_cfg).update(
    [base_data_cfg, base_loss_cfg, base_model_cfg, base_callbacks_cfg]
)

# Config for training models on top of the base ones: same defaults, with the
# finetune data / loss / model / callbacks sections applied instead.
finetune_cfg = Config(default_cfg).update(
    [finetune_data_cfg, finetune_loss_cfg, finetune_model_cfg, finetune_callbacks_cfg]
)

# Experimental Variations.

In [None]:
%%yaml experiment_cfg 

# Per-experiment overrides passed as `exp_cfg` to get_training_configs.
group: "ISLES_CalibratorSuite"

model: 
    base_model_dir: "/storage/vbutoi/scratch/ESE/training/older_runs/2024/September_2024/09_25_24_ISLES_3D_Dice_HeavyAug/20240925_234556-HGRQ-56cd3bf7df7a1b7d2a453b296d64407d"
    # NOTE(review): with every candidate below commented out, `_class` parses as a
    # YAML null, not a calibrator class. Uncomment at least one entry before
    # generating configs, and confirm how get_training_configs treats a null option.
    _class: 
        # - "ese.models.TS"
        # - "ese.models.VS"
        # - "ese.models.DS"
        # - "ese.models.LTS"
        # - "ese.models.IBTS"

# Overrides the batch_size: 16 / num_workers: 4 defaults from default_cfg.
dataloader:
    batch_size: 1
    num_workers: 1

<IPython.core.display.Javascript object>

In [13]:
from ionpy.experiment.generate_configs import get_training_configs

# Build the training configs from experiment_cfg and the chosen base config.
# Swap `finetune_cfg` for `base_cfg` below to target the standard segmentation
# models instead of the models trained on top of them.
updated_base_cfg, train_cfgs = get_training_configs(
    base_cfg=finetune_cfg,  # alternative: base_cfg=base_cfg
    exp_cfg=experiment_cfg,
    scratch_root=SCRATCH_ROOT,
    config_root=CONFIG_ROOT,
    add_date=True,
)

In [14]:
# Sanity check: number of training configs generated above.
len(train_cfgs)

1

# Running Jobs

In [15]:
# ####### FOR DEBUGGING
# Runs only the first generated config, on GPU '0', with W&B tracking disabled.
# NOTE(review): the saved output of this cell ends in
# `ValueError: too many values to unpack (expected 4)`; resolve that before
# relying on this cell.
from ionpy.slite import run_exp
from ese.experiment import CalibrationExperiment

run_exp(
    experiment_class=CalibrationExperiment,
    config=train_cfgs[0],
    run_name='debug',
    gpu='0',
    track_wandb=False,
    show_examples=True,
)

Set seed: 40


Running CalibrationExperiment("/storage/vbutoi/scratch/ESE/training/debug/20250224_174040-9B1V-78fb84b0469dc83650d0ebd2554a2298")
---
callbacks:
  epoch:
  - ionpy.callbacks.ETA
  - ionpy.callbacks.JobProgress
  - ionpy.callbacks.TerminateOnNaN
  - ionpy.callbacks.PrintLogged
  - ionpy.callbacks.ModelCheckpoint:
      monitor:
      - dice_score
      - ece_loss
      phase: val
  step:
  - ionpy.callbacks.ShowPredictions:
      vis_type: segmentation
data:
  in_channels: 1
  out_channels: 1
  preload: false
  return_data_id: true
  train_kwargs:
    split: train
  train_splits: train
  val_kwargs:
    split: val
  val_splits: val
  version: 1.0
dataloader:
  batch_size: 1
  num_workers: 1
  pin_memory: true
experiment:
  data_paths:
  - /storage
  - /storage/vbutoi/datasets
  seed: 40
  seed_range: 1
  sys_paths:
  - /storage/vbutoi/projects
  - /storage/vbutoi/libraries
  - /storage/vbutoi/projects/ESE
  torch_compile: false
  torch_mixed_precision: false
  val_first: false
log:
  ch

W0224 17:40:45.915467 140154052777792 torch/_dynamo/variables/builtin.py:775] [0/0] incorrect arg count <bound method BuiltinVariable.call_enumerate of BuiltinVariable()> got an unexpected keyword argument 'start' and no constant handler


ValueError: too many values to unpack (expected 4)

In [None]:
%%yaml submit_cfg

# Submission settings for the (currently commented-out) submit_exps call in the next cell.
mode: "local"
group: "training"
add_date: True
track_wandb: True
# NOTE(review): duplicates SCRATCH_ROOT from the setup cell; keep the two in sync.
scratch_root: "/storage/vbutoi/scratch/ESE"

In [None]:
# # FOR SUBMISSION
# from ionpy.slite import submit_exps
# from ese.experiment import CalibrationExperiment 

# submit_exps(
#     submit_cfg=submit_cfg,
#     config_list=train_cfgs,
#     exp_cfg=experiment_cfg,
#     base_cfg=updated_base_cfg,
#     experiment_class=CalibrationExperiment,
#     available_gpus=['0', '1', '2', '3', '4', '5', '6', '7'],
# )