In [1]:
# Random Imports
import os
import sys
# Make the local project checkouts and shared libraries importable.
# Entries are appended (not prepended), so anything already on sys.path
# keeps priority over these directories.
sys.path.extend([
    '/storage/vbutoi/projects',
    '/storage/vbutoi/libraries',
    '/storage/vbutoi/projects/ESE',
    '/storage/vbutoi/projects/UniverSeg',
])

# DATAPATH is written as a ':'-joined list of dataset roots (a single root
# here); WANDB_NOTEBOOK_NAME records this notebook's file name for wandb.
os.environ.update({
    'DATAPATH': ':'.join(['/storage/vbutoi/datasets']),
    'WANDB_NOTEBOOK_NAME': 'train.ipynb',
})

# Notebook extensions:
# - yamlmagic supplies the `%%yaml <name>` cell magic used by the config cells below.
# - autoreload in mode 2 reloads imported modules before each cell executes, so
#   edits to the project packages are picked up without restarting the kernel.
%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

## Define the job configs

In [2]:
%%yaml default_cfg 

# Baseline settings shared by every run. The model, dataset and callback
# cells are layered on top of this mapping, and the option sweep overrides
# individual entries through dotted keys (e.g. 'experiment.seed', 'log.root').
experiment:
  seed: 42
    
dataloader:
  batch_size: 8 
  num_workers: 1
  pin_memory: True 

# `_class` / `_fn` values are dotted import paths; the remaining keys are
# presumably forwarded as keyword arguments -- confirm against ionpy.
optim: 
  _class: torch.optim.Adam
  lr: 3.0e-4
  weight_decay: 0.0 
  
train:
  epochs: 1000 
  eval_freq: 5 
  # NOTE: if the line below is enabled, YAML reads `None` as the string "None";
  # write `null` (or leave the key out) to express "no pretrained model".
  # pretrained_dir: None # In case we want to load a pretrained model.

log:
  checkpoint_freq: 20 
  # '?' is a required-value placeholder: the sweep must supply 'log.root',
  # and check_missing verifies no '?' is left before a config is used.
  root: '?'
  metrics:
    dice_score:
      _fn: ionpy.metrics.dice_score
      from_logits: True
      batch_reduction: 'mean' 
      # Overridden to False for WMH by the dataset-specific sweep parameters.
      ignore_empty_labels: True 
      ignore_index: 0 # Ignore background class when reporting.

# Exactly one `loss_func` block should be active at a time; the alternative
# is kept commented out so it can be swapped in.

######################
# Cross-Entropy Loss #
######################
# loss_func: 
#   _class: ionpy.loss.PixelCELoss
#   from_logits: True
#   batch_reduction: 'mean' 
  
#############
# Dice Loss #
#############
loss_func: 
  _class: ionpy.loss.SoftDiceLoss
  from_logits: True
  batch_reduction: 'mean' 
  # Overridden to False for WMH by the dataset-specific sweep parameters.
  ignore_empty_labels: True 
  ignore_index: 0 

<IPython.core.display.Javascript object>

In [3]:
%%yaml model_cfg  

# Network definition, merged into the base config under the `model` key.
# `filters` presumably lists the channel width of each UNet level -- confirm
# against ese.experiment.models.UNet.
model:
  _class: ese.experiment.models.UNet
  filters: [64, 64, 64, 64, 64]
  convs_per_block: 3

<IPython.core.display.Javascript object>

In [4]:
%%yaml dataset_cfg 

# Active dataset. Every alternative further down defines the same top-level
# `data` key, so only one block may be uncommented at a time.
# NOTE(review): the config printed by the debug run shows `data` without
# in_channels/out_channels, so those two keys appear to be consumed
# elsewhere (presumably by the model) -- confirm.

#######
# WMH #
#######
data:
  _class: ese.experiment.datasets.WMH
  axis: 0
  task: Amsterdam
  slicing: dense 
  annotator: observer_o12
  num_slices: 1
  in_channels: 1
  out_channels: 2 
  version: 0.2
  iters_per_epoch: 1000 
  preload: True 

##################
# OASIS 4-Labels #
##################
# data:
#   _class: ese.experiment.datasets.OASIS
#   axis: 0
#   label_set: label4
#   slicing: central 
#   num_slices: 1
#   in_channels: 1
#   out_channels: 5 
#   central_width: 32
#   version: 0.1
#   preload: False 

###################
# OASIS 35-Labels #
###################
# data:
#   _class: ese.experiment.datasets.OASIS
#   axis: 0
#   label_set: label35
#   slicing: central 
#   num_slices: 1
#   in_channels: 1
#   out_channels: 36 
#   central_width: 32
#   version: 0.1
#   preload: False 

##############
# CityScapes #
##############
# data:
#   _class: ese.experiment.datasets.CityScapes
#   in_channels: 3
#   out_channels: 35

##############
# OxfordPets #
##############
# data:
#   _class: ese.experiment.datasets.OxfordPets
#   preload: True
#   version: 0.1
#   num_classes: all 
#   in_channels: 3
#   out_channels: 38 # 37 + 1 (background)

#####################
# Binary OxfordPets #
#####################
# data:
#   _class: ese.experiment.datasets.BinaryPets
#   preload: True
#   version: 0.1
#   in_channels: 3
#   out_channels: 2

<IPython.core.display.Javascript object>

In [5]:
%%yaml lite_aug_cfg

# Light augmentation pipeline: an ordered list of single-key mappings of the
# form {TransformName: kwargs}. The list is attached to every generated
# config under the 'augmentations' key.
# `p` is presumably the probability of applying each transform -- confirm.
- RandomVariableElasticTransform:
    p: 0.5
    alpha: [1, 2] 
    sigma: [8, 10]
- RandomHorizontalFlip:
    p: 0.5
- RandomVerticalFlip:
    p: 0.5

<IPython.core.display.Javascript object>

In [6]:
%%yaml callbacks_cfg

# Callbacks grouped under `step` and `epoch`; the grouping presumably selects
# when each callback fires -- confirm against the experiment class.
# Entries are dotted import paths; an entry written as a mapping (see
# ModelCheckpoint) carries extra settings for that callback.
callbacks:
  step:
    - ese.experiment.callbacks.ShowPredictions
  epoch:
    - ese.experiment.callbacks.WandbLogger
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    # Tracks the validation dice score for checkpointing.
    - ionpy.callbacks.ModelCheckpoint:
        monitor: dice_score
        phase: val

<IPython.core.display.Javascript object>

## Debug Station

In [7]:
# Set up directories
train_root = '/storage/vbutoi/scratch/ESE/training'

# WMH CONFIG
exp_name = '01_04_24_WMH_AugEnsemble'
# NOTE(review): mod_filters and dl_bs are not referenced by any visible cell;
# they currently mirror the defaults in the YAML configs -- confirm whether
# they were meant to feed 'model.filters' / 'dataloader.batch_size'.
mod_filters = [64] * 5
dl_bs = 8
# Dataset-specific overrides: set False for WMH, True otherwise.
wmh_params = dict([
    ('log.metrics.dice_score.ignore_empty_labels', [False]),
    ('loss_func.ignore_empty_labels', [False]),  # USE FOR DICE
])

# BINARY PETS CONFIG
# exp_name = '11_20_23_BinaryPets_CrossEntropy'
# mod_filters = [64, 64, 64, 64, 64]
# dl_bs = 8

# OASIS 4-Label CONFIG
# exp_name = '11_20_23_OASIS4_CrossEntropy'
# mod_filters = [64, 64, 64, 64, 64]
# dl_bs = 8

# OASIS 35-Label CONFIG
# exp_name = '11_20_23_OASIS35_CrossEntropy'
# mod_filters = [64, 64, 64, 64, 64]
# dl_bs = 8

# CityScapes CONFIG
# exp_name = '11_20_23_CityScapes_CrossEntropy'
# mod_filters = [64, 64, 64, 64, 64]
# dl_bs = 8

# Every run of this experiment logs under <train_root>/<exp_name>.
exp_root = '/'.join((train_root, exp_name))

# Ablation options: each key is a dotted config path and each value is the
# list of settings to sweep over. The dataset-specific overrides are merged
# in last.
option_set = [
    {
        'experiment.seed': list(range(40, 44)),
        'log.root': [exp_root],
        **wmh_params,
    }
]

In [8]:
import copy
from ionpy.util import Config
from ionpy.util.config import check_missing
from ese.scripts.utils import get_option_product

# Layer the model, dataset and callback sections on top of the defaults.
config_layers = [model_cfg, dataset_cfg, callbacks_cfg]
base_cfg = Config(default_cfg).update(config_layers)

# Flatten (copies of) the augmentation lists into a single list. Only the
# "lite" list is in use, so this amounts to a deep copy of it.
augmentation_groups = [copy.deepcopy(lite_aug_cfg)]
light_augmentations = [aug for group in augmentation_groups for aug in group]

# Expand the option set into the individual run configs.
raw_cfgs = get_option_product(exp_name, option_set, base_cfg)

# Attach the augmentations to each config; every config receives the same
# list object.
cfgs = []
for raw_cfg in raw_cfgs:
    augmented_cfg = raw_cfg.set('augmentations', light_augmentations)
    check_missing(augmented_cfg)  # Verify no '?' placeholders remain in the config.
    cfgs.append(augmented_cfg)

In [9]:
# Sanity check: number of run configs produced by the option sweep.
len(cfgs)

4

## Running Jobs

In [None]:
# ######## FOR DEBUGGING
# from ese.experiment.experiment import run_ese_exp, CalibrationExperiment

# run_ese_exp(
#     config=cfgs[0], 
#     experiment_class=CalibrationExperiment,
#     gpu='0',
#     run_name='debug',
#     show_examples=False,
#     track_wandb=True
# )

In [10]:
# FOR SUBMISSION
from ese.experiment.experiment import CalibrationExperiment, submit_ese_exps

# Only the first generated config is submitted, with GPU '1' as the sole
# available device.
submission_kwargs = dict(
    exp_root=exp_root,
    experiment_class=CalibrationExperiment,
    config_list=[cfgs[0]],
    available_gpus=['1'],
    track_wandb=True,
)
submit_ese_exps(**submission_kwargs)

Initalized SliteRunner


Submitted job id: 338404.


In [11]:
######## FOR DEBUGGING
from ese.experiment.experiment import CalibrationExperiment, run_ese_exp

# Run the first config under the run name 'debug' on GPU '0'; the training
# log streams into this cell's output.
debug_run_kwargs = dict(
    config=cfgs[0],
    experiment_class=CalibrationExperiment,
    gpu='0',
    run_name='debug',
    show_examples=False,
    track_wandb=True,
)
run_ese_exp(**debug_run_kwargs)

  warn("Using slow Pillow instead of Pillow-SIMD")


Running CalibrationExperiment("/storage/vbutoi/scratch/ESE/training/debug/20240104_135747-PFNH-543820e5dcdbc1a4b51974a56cf416f5")
---
augmentations:
- RandomVariableElasticTransform:
    alpha:
    - 1
    - 2
    p: 0.5
    sigma:
    - 8
    - 10
- RandomHorizontalFlip:
    p: 0.5
- RandomVerticalFlip:
    p: 0.5
callbacks:
  epoch:
  - ese.experiment.callbacks.WandbLogger
  - ionpy.callbacks.ETA
  - ionpy.callbacks.JobProgress
  - ionpy.callbacks.TerminateOnNaN
  - ionpy.callbacks.PrintLogged
  - ionpy.callbacks.ModelCheckpoint:
      monitor: dice_score
      phase: val
data:
  _class: ese.experiment.datasets.WMH
  annotator: observer_o12
  axis: 0
  iters_per_epoch: 1000
  num_slices: 1
  preload: true
  slicing: dense
  task: Amsterdam
  version: 0.2
dataloader:
  batch_size: 8
  num_workers: 1
  pin_memory: true
experiment:
  seed: 40
log:
  checkpoint_freq: 20
  metrics:
    dice_score:
      _fn: ionpy.metrics.dice_score
      batch_reduction: mean
      from_logits: true
    

[34m[1mwandb[0m: Currently logged in as: [33mvbutoi[0m. Use [1m`wandb login --relogin`[0m to force relogin


Start epoch 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Checkpointing with tag:last at epoch:0
ETA (0/999): 2024-01-04 13:59:13 - 00:00:00 remaining
Logged @ Epoch 0
metric         train       val
----------  --------  --------
dice_score  0.167379  0.325349
loss        0.799555  0.626772
Checkpointing with tag:max-val-dice_score at epoch:0
Start epoch 1
ETA (1/999): 2024-01-04 19:40:19 - 05:40:45 remaining
Logged @ Epoch 1
metric         train
----------  --------
dice_score  0.358116
loss        0.593157
Start epoch 2
ETA (2/999): 2024-01-04 19:39:21 - 05:39:27 remaining
Logged @ Epoch 2
metric         train
----------  --------
dice_score  0.453615
loss        0.498473
Start epoch 3
ETA (3/999): 2024-01-04 19:39:15 - 05:39:00 remaining
Logged @ Epoch 3
metric         train
----------  --------
dice_score  0.494633
loss        0.466475
Start epoch 4


KeyboardInterrupt: 