In [None]:
# Random Imports
import os
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

# Regular schema dictates that we put DATAPATH
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))
os.environ['WANDB_NOTEBOOK_NAME'] = 'train.ipynb'

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

## Define the job configs

In [None]:
%%yaml default_cfg

# Baseline config for the runs built in this notebook. `log.root` is a '?'
# placeholder here; option_set overrides it via its 'log.root' entry.

experiment:
  seed: 42

dataloader:
  batch_size: 1
  num_workers: 2
  pin_memory: False

optim:
  _class: torch.optim.Adam
  lr: 3.0e-4
  weight_decay: 0.0

# For MultiClass
train:
  epochs: 100
  eval_freq: 5
  # YAML has no `None` literal: a bare None loads as the string 'None', which
  # is truthy in Python. `null` is the value that loads as Python None.
  augmentations: null

log:
  checkpoint_freq: 5
  root: '?'
  metrics:
    dice_score:
      _fn: ionpy.metrics.dice_score
      from_logits: True
      batch_reduction: 'mean'
      ignore_index: 0

loss_func:
  _class: ionpy.loss.SoftDiceLoss
  from_logits: True
  batch_reduction: 'mean'
  ignore_index: 0 # Make sure to ignore the background class

# For Binary (alternative to the MultiClass train/log/loss_func sections above;
# note it sets no ignore_index)

# train:
#   epochs: 500
#   eval_freq: 50

# log:
#   checkpoint_freq: 50
#   root: '?'
#   metrics:
#     dice_score:
#       _fn: ionpy.metrics.dice_score
#       from_logits: True
#       batch_reduction: 'mean'

# loss_func:
#   _class: ionpy.loss.SoftDiceLoss
#   from_logits: True
#   batch_reduction: 'mean'

In [None]:
%%yaml model_cfg  

# Model section of the config: selects the UNet class from the ese package.
model:
  _class: ese.experiment.models.UNet
  # Placeholder; the 'model.filters' entry in option_set supplies the real value.
  filters: '?'
  convs_per_block: 3

In [None]:
%%yaml dataset_cfg 

# Dataset section of the config. Only one `data:` block is active at a time;
# the commented-out blocks below are alternative datasets (WMH, COCO).
#
# data:
#   _class: ese.experiment.datasets.WMH
#   annotator: observer_o12
#   axis: 0
#   preload: False 
#   num_slices: 1
#   task: Amsterdam 
#   version: 0.2
#   in_channels: 1
#   out_channels: 2 

# data:
#   _class: ese.experiment.datasets.COCO 
#   in_channels: 3
#   out_channels: 80

# Active dataset: OxfordPets, with 3 input channels and 38 output channels.
data:
  _class: ese.experiment.datasets.OxfordPets
  preload: True
  version: 0.1
  in_channels: 3
  out_channels: 38 # 37 + 1 (background)

In [None]:
%%yaml lite_aug_cfg

# Augmentation list; option_set passes it under its 'augmentations' key.
# Each entry maps a transform name to its settings. `p` is presumably the
# probability of applying that transform — TODO confirm against the
# augmentation builder.
- RandomAffine:
    p: 0.5
    degrees: [0, 360]
    translate: [0, 0.2]
    scale: [0.8, 1.1]
- RandomVariableElasticTransform:
    p: 0.75
    alpha: [1, 2] 
    sigma: [7, 9]
- RandomHorizontalFlip:
    p: 0.5
- RandomVerticalFlip:
    p: 0.5

In [None]:
# %%yaml standard_aug_cfg 

# - RandomCrop:
#     _class: ese.experiment.augmentation.transforms.RandomCropSegmentation
#     size: [256, 256]

In [None]:
%%yaml callbacks_cfg

# Callback section of the config, split into `step` and `epoch` groups.
callbacks:
  # NOTE(review): every entry under `step` is commented out, so this key loads
  # as null rather than an empty list — confirm the callback builder accepts that.
  step:
    # - ese.experiment.callbacks.ShowPredictions:
    #     mode: multiclass
    #     label_cmap: tab10 
    # - ese.experiment.callbacks.ShowPredictions:
    #     mode: binary 
    #     label_cmap: gray 
  epoch:
    # - ese.experiment.callbacks.WandbLogger
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    # The only callback given arguments here; `dice_score` is also the metric
    # name declared under log.metrics in default_cfg.
    - ionpy.callbacks.ModelCheckpoint:
        monitor: dice_score
        phase: val

## Debug Station

In [None]:
# Experiment name; it is also the last path component of the log root below.
# exp_name = 'COCO_runs'
exp_name = 'debug'

# Ablation options: each dict maps a dotted config key to the list of candidate
# values for that key. Every list here holds a single value.
# NOTE(review): 'augmentations' is a top-level key while default_cfg declares
# train.augmentations — confirm which one the experiment actually reads.
debug_options = {
    'log.root': ['/storage/vbutoi/scratch/ESE/' + exp_name],
    'augmentations': [lite_aug_cfg],
    'dataloader.batch_size': [1],
    'optim.lr': [3.0e-4],
    'model.filters': [[128] * 5],
}
option_set = [debug_options]

In [None]:
from ese.scripts.utils import get_option_product
from ionpy.util import Config


# Build the base config: start from the defaults, then update with the model,
# dataset and callback sections (in that order).
section_cfgs = [model_cfg, dataset_cfg, callbacks_cfg]
base_cfg = Config(default_cfg).update(section_cfgs)

# Expand option_set against the base config to get the configs to run.
cfgs = get_option_product(exp_name, option_set, base_cfg)

In [None]:
# Sanity check: how many entries get_option_product returned.
len(cfgs)

## Running Jobs

In [None]:
# Submit cell
from ionpy import slite
from ese.experiment.experiment.ese_exp import CalibrationExperiment 

In [None]:
# Launch a single run using only the first generated config.
# gpu='3' is presumably the device index to run on — confirm against slite.
first_cfg = cfgs[0]
slite.run_exp(config=first_cfg, exp_class=CalibrationExperiment, gpu='3')

In [None]:
# Submit the experiments
# slite.submit_exps(
#     project="ESE",
#     exp_name=exp_name,
#     exp_class=CalibrationExperiment,
#     available_gpus=['0', '1', '2', '3'],
#     config_list=cfgs
# )