In [1]:
# Define useful paths
SCRATCH_ROOT = "/storage/vbutoi/scratch/ESE/SeBench"
CONFIG_ROOT = "/storage/vbutoi/projects/SeBench/configs"

# IonPy imports
from ionpy.util import Config

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [None]:
%%yaml default_cfg

# Defaults shared by every run. '?' looks like a placeholder: experiment_cfg
# supplies seed / batch_size / num_workers for this notebook.
experiment:
  seed: '?'
  val_first: False
  torch_compile: True
  torch_mixed_precision: False
  sys_paths: ["/storage/vbutoi/projects", "/storage/vbutoi/projects/SeBench"]
  data_paths: ["/storage", "/storage/vbutoi/datasets"]

dataloader:
  batch_size: '?'
  num_workers: '?'
  pin_memory: True

optim:
  _class: torch.optim.Adam
  lr: 1.0e-4  # keep the '.0': PyYAML (YAML 1.1) reads `1e-4` as a string

train:
  epochs: 300
  eval_freq: 10

# Per-split dataset kwargs (used for the additional-data experiment);
# experiment_cfg provides data._class / task / resolution.
data:
  train_kwargs:
    split: 'train'
  val_kwargs:
    split: 'val'

loss_func:
  _class: sebench.losses.SoftDiceLoss
  from_logits: True
  batch_reduction: 'mean'
  ignore_empty_labels: False

In [None]:
%%yaml log_cfg

# Logging / checkpointing settings and the metrics to log.
log:
  root: "?"  # not set by experiment_cfg — presumably filled in by get_training_configs
  checkpoint_freq: 10
  metrics:
    # Key name is what ModelCheckpoint's `monitor` refers to in callbacks_cfg.
    # Reduction / logits options mirror loss_func in default_cfg.
    dice_score:
      _fn: ionpy.metrics.dice_score
      batch_reduction: "mean"
      ignore_empty_labels: False
      from_logits: True

In [None]:
%%yaml model_cfg

# Segmentation network definition.
model:
  _class: sebench.models.Segmenter
  in_channels: 1
  out_channels: 2
  dec_type: "mask"
  # NOTE(review): presumably must agree with data.resolution (64) in experiment_cfg — confirm.
  img_res: [64, 64]

In [None]:
%%yaml callbacks_cfg

# Callbacks keyed by hook: `step` entries and `epoch` entries.
callbacks:
  step: [sebench.callbacks.ShowPredictions]
  epoch:
    - ionpy.callbacks.WandbLogger
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    # Checkpoint on the val-phase `dice_score` metric defined in log_cfg.
    - ionpy.callbacks.ModelCheckpoint:
        monitor: dice_score
        phase: val

In [6]:
# %%yaml aug_cfg
# # Optional augmentation config, currently disabled. To enable it, uncomment
# # this whole cell (keeping `%%yaml aug_cfg` as the very first line, since a
# # cell magic must start the cell) and include `aug_cfg` in the list passed to
# # `Config(default_cfg).update(...)` in the config-building cell.
# # NOTE(review): the meaning of each key (warp_* ranges, *_probability, ...)
# # is defined by the augmentation code, which is not visible here.

# augmentations:
#     spatial:
#         max_translation: 5.0
#         max_rotation: 5.0
#         max_scaling: 1.1
#         warp_integrations: 5
#         warp_smoothing_range: [10, 20]
#         warp_magnitude_range: [1, 3]
#         affine_probability: 0.5
#         warp_probability: 0.5
#     visual:
#         use_mask: False 
#         added_noise_max_sigma: 0.01
#         gamma_scaling_max: 0.1
#         bias_field_probability: 0.5
#         gamma_scaling_probability: 0.5
#         added_noise_probability: 0.5

In [None]:
%%yaml experiment_cfg

# Experiment-specific settings, passed as `exp_cfg` to get_training_configs
# together with the base config (presumably as overrides — confirm).
group: "WBC_Segmenter"

experiment:
  seed: 40
  seed_range: 1
  torch_compile: False

data:
  _class: "sebench.datasets.Segment2D"
  task: "WBC/CV/EM/0"
  resolution: 64

dataloader:
  batch_size: 1
  num_workers: 1

In [None]:
# Assemble the base config from the YAML cells above, then expand it together
# with experiment_cfg into the per-run training configs.
from sebench.scripts import get_training_configs

# Set to True to also merge `aug_cfg`. The aug_cfg cell is commented out by
# default and must be uncommented and run first, otherwise `aug_cfg` is undefined.
USE_AUGMENTATIONS = False

cfg_updates = [model_cfg, log_cfg, callbacks_cfg]
if USE_AUGMENTATIONS:
    cfg_updates.append(aug_cfg)
base_cfg = Config(default_cfg).update(cfg_updates)

# get_training_configs returns a pair: the updated base config (used by the
# submission cell) and the list of training configs (one per run).
updated_base_cfg, train_cfgs = get_training_configs(
    exp_cfg=experiment_cfg,
    base_cfg=base_cfg,
    config_root=CONFIG_ROOT,
    scratch_root=SCRATCH_ROOT,
)

In [None]:
# Number of generated training configs. The debug cell below runs train_cfgs[0],
# so fail here with a clear message if nothing was generated.
assert len(train_cfgs) > 0, "get_training_configs returned no training configs"
len(train_cfgs)

## Running Jobs

In [None]:
# Debug run: execute the first generated training config on GPU '0' with
# example display on and wandb tracking off.
from ionpy.slite import run_exp
from sebench.experiment import SegTrainExperiment

debug_config = train_cfgs[0]
debug_options = dict(
    run_name='debug',
    show_examples=True,
    track_wandb=False,
    gpu='0',
)
run_exp(
    config=debug_config,
    experiment_class=SegTrainExperiment,
    **debug_options,
)

In [None]:
# Job-submission settings for `submit_exps`. Built in Python rather than a
# %%yaml cell so that `scratch_root` reuses SCRATCH_ROOT from the setup cell
# instead of duplicating the hard-coded path (the two could silently diverge).
submit_cfg = {
    "mode": "local",
    "group": "training",
    "add_date": True,
    "track_wandb": True,
    "scratch_root": SCRATCH_ROOT,
}

In [None]:
# Submit every generated training config as a job. Disabled by default so that
# "Run All" does not launch the sweep; set SUBMIT_JOBS = True to submit.
SUBMIT_JOBS = False

if SUBMIT_JOBS:
    from ionpy.slite import submit_exps
    # Same import path as the debug cell.
    from sebench.experiment import SegTrainExperiment

    submit_exps(
        submit_cfg=submit_cfg,
        config_list=train_cfgs,
        exp_cfg=experiment_cfg,
        base_cfg=updated_base_cfg,
        experiment_class=SegTrainExperiment,
        available_gpus=['0', '1', '2', '3', '4', '5', '6', '7'],
    )