In [None]:
# Random Imports
import os
import sys
import yaml
from pathlib import Path
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSegDev')

# Regular schema dictates that we put DATAPATH
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))
os.environ['WANDB_NOTEBOOK_NAME'] = 'train.ipynb'

from ionpy.util import Config

# Setup direcrtories
root = Path("/storage/vbutoi/scratch/ESE")
code_root = Path("/storage/vbutoi/projects/ESE")
scratch_root = Path("/storage/vbutoi/scratch/ESE")

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [None]:
%%yaml default_cfg 

# Default training settings. The config-building cell further down starts
# from these values and then applies model_cfg, log_cfg and callbacks_cfg.

experiment:
  # '?' appears to be a "must be supplied" placeholder (experiment_cfg sets
  # experiment.seed). It must stay quoted: a bare `?` is YAML's complex-key
  # indicator.
  seed: '?' 
    
dataloader:
  batch_size: 8 
  num_workers: 1 
  pin_memory: True 

optim: 
  # Dotted import path of the optimizer class; the remaining keys look like
  # its constructor kwargs (assumption, confirm against the Config consumer).
  _class: torch.optim.Adam
  weight_decay: 0.0 
  # Keep the `1.0e-4` spelling. If the %%yaml magic parses with PyYAML, its
  # YAML 1.1 float resolver requires a decimal point, so `1e-4` would load
  # as a string.
  lr: 1.0e-4
  
train:
  epochs: 500 
  eval_freq: 10  # presumably measured in epochs (TODO confirm)

# Used for additional data experiment.
data:
  train_splits: train
  val_splits: val
  
loss_func: 
  # Placeholder; experiment_cfg supplies the concrete loss class(es).
  _class: '?'
  from_logits: True
  batch_reduction: 'mean' 

In [None]:
%%yaml log_cfg

log:
  # Placeholder, presumably filled in per run by the config tooling (TODO confirm).
  root: '?'
  checkpoint_freq: 20 
  # Each metric gives the dotted path of its function in `_fn`. The other
  # keys look like kwargs passed to that function (assumption).
  metrics:
    # `dice_score` is also the metric monitored by ModelCheckpoint in
    # callbacks_cfg, so renaming it here requires updating that cell too.
    dice_score:
      _fn: ionpy.metrics.dice_score
      batch_reduction: 'mean' 
      ignore_empty_labels: False 
      from_logits: True
    ece_loss:
      _fn: ese.metrics.image_ece_loss
      num_prob_bins: 15
      from_logits: True

In [None]:
%%yaml model_cfg  

model:
  _class: ese.models.unet.UNet
  # Five entries of 64; presumably the channel width at each UNet level
  # (verify against ese.models.unet.UNet).
  filters: [64, 64, 64, 64, 64]
  convs_per_block: 3 # Good default for UNets.

In [None]:
%%yaml callbacks_cfg

# Callbacks listed by dotted import path, presumably grouped by when they
# fire (`step` vs `epoch`). An entry written as a mapping carries kwargs.
callbacks:
  step:
    - ese.callbacks.ShowPredictions
  epoch:
    - ese.callbacks.WandbLogger
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    - ionpy.callbacks.ModelCheckpoint:
        # Tracks the `dice_score` metric defined in log_cfg, on the val phase.
        monitor: dice_score
        phase: val

In [None]:
%%yaml experiment_cfg 

# NOTE(review): "Heptatic" looks like a misspelling of "Hepatic" (the task
# below is MSD/HepaticVessel). It is left as-is because the name may identify
# existing runs or outputs; rename it deliberately, if at all.
name: "HeptaticVessel_LowerLR"

# SVLS (presumably spatially varying label smoothing) applied to training data.
augmentations:
    train:
        - ese.augmentation.SVLS:
            ksize: 3
            sigma: 1
            always_apply: True
            # NOTE(review): the original comment read "For ACDC we don't
            # include the center pixel", but this config targets MSD
            # HepaticVessel. Confirm include_center: False is intended here.
            include_center: False

experiment:
    seed: 40
    # Presumably get_ese_training_configs expands this into several runs with
    # seeds derived from `seed` (verify).
    seed_range: 4

optim:
    lr: 5.0e-5  # half of the 1.0e-4 default (the "LowerLR" in `name`)

## Setup for homegrown datasets.
# data:
    # _class: "ese.datasets.OCTA_6M"
    # label_threshold: 0.5
    # label: 255
    # version: 1.0 # Full resolution version.

## Setup for UniverSeg datasets.
data:
    _class: "universeg.experiment.datasets.Segment2D"
    root_folder: "MSD/thunder_MSD/v4.2"
    task: 'MSD/HepaticVessel/CT/2'
    # task: 'MSD/Pancreas/PVP-CT/2'
    resolution: 256
    label: 0

loss_func:
    # A list of loss classes, whereas default_cfg has a single '?'. Presumably
    # get_ese_training_configs creates one run per entry (verify).
    _class: 
        - ese.losses.SoftDiceLoss
        - ese.losses.PixelCELoss

In [None]:
from ese.analysis.analysis_utils.submit_utils import get_ese_training_configs

# Base config shared by all runs: the defaults, updated with the model,
# logging and callback cells.
base_cfg = Config(default_cfg).update([model_cfg, log_cfg, callbacks_cfg])

# Per-run training configs built from the experiment overrides plus the base
# config (indexed and counted by the cells below).
train_cfgs = get_ese_training_configs(exp_cfg=experiment_cfg, base_cfg=base_cfg)

In [None]:
# Number of run configs generated by the previous cell.
len(train_cfgs)

In [None]:
from pprint import pprint

# Sanity check: pretty-print the first generated run config before launching.
pprint(train_cfgs[0], indent=1, width=80)

## Running Jobs

In [None]:
# ---- Local debug run ---------------------------------------------------------
# Runs only the first generated config, in-process, on a single GPU, with
# `show_examples` enabled and W&B tracking turned off.
from ese.experiment import run_ese_exp, CalibrationExperiment

run_ese_exp(
    config=train_cfgs[0],
    experiment_class=CalibrationExperiment,
    gpu='0',  # alternative device used previously: gpu='4'
    run_name='debug',
    show_examples=True,
    track_wandb=False,
)

In [None]:
# # FOR SUBMISSION
# Uncomment to submit every config in `train_cfgs` (rather than the single
# debug run above) across the listed GPUs, with W&B tracking enabled.
# from ese.experiment import submit_ese_exps, CalibrationExperiment 

# submit_ese_exps(
#     config_list=train_cfgs,
#     experiment_class=CalibrationExperiment,
#     available_gpus=['0', '1', '2', '3'],
#     # available_gpus=['4', '5', '6', '7'],
#     track_wandb=True
# )