In [1]:
# Random Imports
import os
import sys
import yaml
from pathlib import Path
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSegDev')

# Regular schema dictates that we put DATAPATH
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))
os.environ['WANDB_NOTEBOOK_NAME'] = 'train.ipynb'

# IonPy imports
from ionpy.util import Config

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [None]:
%%yaml default_cfg 

# Base defaults shared by every run. Values written as '?' are placeholders with no
# default here; presumably they must be filled in later (e.g. by experiment_cfg or
# get_ese_training_configs) before training — TODO confirm.
experiment:
  seed: '?'
  torch_compile: False
  torch_mixed_precision: False

dataloader:
  batch_size: '?'
  num_workers: '?'
  pin_memory: True

optim:
  _class: torch.optim.Adam
  lr: '?'

train:
  epochs: '?'
  eval_freq: '?'

# Per-split dataset kwargs (originally noted as used for the additional-data experiment).
# The dataset class itself (data._class) is supplied by experiment_cfg.
data:
  train_kwargs:
    split: 'train'
  val_kwargs:
    split: 'val'

# The loss class is a placeholder; experiment_cfg selects the concrete loss.
loss_func:
  _class: '?'
  from_logits: True
  batch_reduction: 'mean'

In [None]:
%%yaml log_cfg

# Logging settings: output root, checkpoint cadence, and the metric functions to log.
log:
  # Placeholder ('?'): must be supplied before training — TODO confirm where it is set.
  root: '?'
  # NOTE(review): the unit of this frequency (epochs vs. steps) is not visible here — confirm.
  checkpoint_freq: 20
  metrics:
    # `dice_score` is also the metric that ModelCheckpoint monitors in callbacks_cfg.
    dice_score:
      _fn: ionpy.metrics.dice_score
      batch_reduction: 'mean'
      ignore_empty_labels: False
      from_logits: True
    abs_area_estimation_error:
      _fn: ese.losses.area_estimation_error
      from_logits: True
      abs_diff: True

In [None]:
%%yaml model_cfg  

# UNet model; both architecture sizes are '?' placeholders with no default set here.
model:
  _class: ese.models.unet.UNet
  filters: '?'
  convs_per_block: '?' # Placeholder: no default provided here; must be supplied per experiment (TODO confirm where).

In [None]:
%%yaml callbacks_cfg

# Callbacks grouped by trigger: `step` entries presumably fire every training step,
# `epoch` entries at each epoch boundary — TODO confirm against the trainer.
callbacks:
  step:
    - ese.callbacks.ShowPredictions
  epoch:
    - ese.callbacks.WandbLogger
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    # Checkpointing keyed on the validation-phase `dice_score` (defined under log.metrics).
    - ionpy.callbacks.ModelCheckpoint:
        monitor: dice_score
        phase: val

In [None]:
%%yaml aug_cfg

# Spatial and intensity ("visual") augmentation settings.
# NOTE: as written, the config-building cell below only merges the model/log/callbacks
# configs into base_cfg. These values therefore have no effect unless that cell is
# changed to include aug_cfg.
augmentations:
    spatial:
        max_translation: 5.0
        max_rotation: 5.0
        max_scaling: 1.1
        warp_integrations: 5
        warp_smoothing_range: [10, 20]
        warp_magnitude_range: [1, 3]
        affine_probability: 0.5
        warp_probability: 0.5
    visual:
        use_mask: False
        added_noise_max_sigma: 0.05
        gamma_scaling_max: 0.8
        bias_field_probability: 0.5
        gamma_scaling_probability: 0.5
        added_noise_probability: 0.5

In [None]:
%%yaml experiment_cfg 

# Per-sweep settings, passed to get_ese_training_configs together with base_cfg
# (presumably overriding the matching base values).
name: "10_10_24_Roads_FullRes_big_CrossEntropy"

experiment:
    seed: 40
    # NOTE(review): presumably the number of seeds to sweep starting at `seed` — TODO confirm.
    seed_range: 3
    # Turns on torch_compile (default_cfg sets it to False).
    torch_compile: True

data:
    _class: "ese.datasets.Roads"
    # NOTE(review): unquoted, this loads as the float 0.1 rather than the string "0.1";
    # confirm the dataset expects a number.
    version: 0.1

dataloader:
  # Earlier (larger) settings, kept for reference:
  # batch_size: 4
  # num_workers: 2
  batch_size: 1
  num_workers: 1

optim:
  # Keep the decimal point: a YAML 1.1 loader such as PyYAML reads `1e-4` as a string.
  lr: 1.0e-4

loss_func:
    # Alternative loss, kept for reference:
    # _class: ese.losses.SoftDiceLoss
    _class: ese.losses.PixelCELoss

In [8]:
from ese.analysis.analysis_utils.submit_utils import get_ese_training_configs

# Set to True to also merge the augmentation settings (aug_cfg) into the base config.
# Replaces the commented-out alternative `base_cfg = ...` line; False keeps the
# original behavior.
USE_AUGMENTATIONS = False

# Layer the shared config cells on top of the defaults, in list order
# (presumably later entries take precedence — TODO confirm Config.update semantics).
base_cfg_parts = [model_cfg, log_cfg, callbacks_cfg]
if USE_AUGMENTATIONS:
    base_cfg_parts.append(aug_cfg)
base_cfg = Config(default_cfg).update(base_cfg_parts)

# Expand the experiment config into the per-run training configs. Also returns the
# base config as updated by get_ese_training_configs; that version is what gets submitted.
updated_base_cfg, train_cfgs = get_ese_training_configs(
    exp_cfg=experiment_cfg,
    base_cfg=base_cfg,
)

In [None]:
# Fail fast on an empty sweep: the cells below index train_cfgs[0] and submit train_cfgs.
assert len(train_cfgs) > 0, "get_ese_training_configs produced no training configs."
# Number of training runs generated from experiment_cfg.
len(train_cfgs)

In [None]:
from pprint import pprint

# Inspect the first generated run config before launching anything.
first_train_cfg = train_cfgs[0]
pprint(first_train_cfg)

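## Running Jobs

Use the debug cell to run the first generated config with W&B tracking off. Then use the submission cell to launch every config in `train_cfgs`. Note that *Run All* executes both cells.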

In [None]:
# FOR DEBUGGING: run only the first generated config, under the run name 'debug',
# with show_examples on and W&B tracking off.
from ese.experiment import run_ese_exp, CalibrationExperiment

debug_cfg = train_cfgs[0]
run_ese_exp(
    config=debug_cfg,
    experiment_class=CalibrationExperiment,
    run_name='debug',
    show_examples=True,
    track_wandb=False,
    gpu='5',
)

In [None]:
# FOR SUBMISSION: submit the full sweep (every config in train_cfgs) under the
# "training" group, with W&B tracking on.
# NOTE(review): this cell also executes on "Run All" — make sure that is intended.
from ese.experiment import submit_ese_exps, CalibrationExperiment

submission_gpus = ['5', '6', '7']
submit_ese_exps(
    group="training",
    base_cfg=updated_base_cfg,
    exp_cfg=experiment_cfg,
    config_list=train_cfgs,
    experiment_class=CalibrationExperiment,
    available_gpus=submission_gpus,
    track_wandb=True,
)