In [None]:
# Environment setup: make the local project/library checkouts importable and
# point dataset loading at the shared dataset root.
import os
import sys

# Source checkouts that must be importable for the project packages used below.
PROJECT_PATHS = [
    '/storage/vbutoi/projects',
    '/storage/vbutoi/libraries',
    '/storage/vbutoi/projects/ESE',
    '/storage/vbutoi/projects/UniverSeg',
]


def extend_sys_path(paths):
    """Append each entry of ``paths`` to ``sys.path`` unless it is already there.

    Keeps this cell idempotent: re-running it (routine in a notebook) no longer
    piles up duplicate ``sys.path`` entries. Order of first appearance is kept.
    """
    for path in paths:
        if path not in sys.path:
            sys.path.append(path)


extend_sys_path(PROJECT_PATHS)

# Dataset roots, joined with ':' like PATH and exported as DATAPATH
# (presumably read by the dataset classes, e.g. ese.experiment.datasets.WMH — confirm).
os.environ['DATAPATH'] = ':'.join((
    '/storage/vbutoi/datasets',
))
# Parent directory for every experiment's log.root (see the option sets below).
log_root_dir = '/storage/vbutoi/scratch/ESE'

# yamlmagic provides the %%yaml cell magic used by the config cells below;
# autoreload 2 re-imports edited modules before each cell executes, so changes
# to the project code are picked up without restarting the kernel.
%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

## Initialize the runner used to run jobs.

In [None]:
# Runner setup: one SliteRunner for the whole notebook. Later cells set its
# experiment name and run/submit configs through it; it presumably dispatches
# CalibrationExperiment jobs onto the GPUs listed here — confirm.
from ese.experiment.experiment.ese_exp import CalibrationExperiment
from ionpy.slite import SliteRunner

# GPU ids available on this machine, as strings: '0', '1', '2', '3'.
available_gpus = [str(gpu_id) for gpu_id in range(4)]

srunner = SliteRunner(
    task_type=CalibrationExperiment,
    available_gpus=available_gpus,
)

## Define the jobs

In [None]:
from ionpy.util.config import check_missing

def validate_cfg(cfg):
    """Sanity-check a fully assembled experiment config before it is run.

    Currently this only delegates to ``check_missing`` from
    ``ionpy.util.config`` (presumably raising if a '?' placeholder is still
    unset — confirm). This is also the place to cross-check or force
    inter-related settings.

    NOTE(review): the debug option set uses ``data.num_slices: 2`` while
    ``model.in_channels`` stays 1. If slices are stacked as input channels,
    those two settings would need to be forced to agree here. Confirm.

    Returns:
        The same ``cfg`` object, so the call can be chained.
    """
    check_missing(cfg)
    return cfg

In [None]:
%%yaml default_cfg 

# Base configuration shared by every job. Entries set to '?' are placeholders
# filled in per job by the option sets; validate_cfg calls check_missing on
# each assembled config (presumably to reject any '?' left unset — confirm).
experiment:
  seed: 42
    
dataloader:
  # Set per job via 'dataloader.batch_size' in the option sets.
  batch_size: '?' 
  num_workers: 4 
  pin_memory: False 

optim: 
  _class: torch.optim.Adam
  lr: 3.0e-4
  # The main sweep overrides this via 'optim.weight_decay'.
  weight_decay: 0.0 
  
train:
  epochs: 500
  eval_freq: 50
  
log:
  # Same value as train.eval_freq (presumably both counted in epochs — confirm).
  checkpoint_freq: 50 
  # Set per job via 'log.root' in the option sets.
  root: '?'
  metrics:
    # dice_score is the metric ModelCheckpoint monitors in callbacks_cfg.
    dice_score:
      _fn: ionpy.metrics.dice_score
      from_logits: True
      batch_reduction: 'mean' 

loss_func: 
  _class: ionpy.loss.SoftDiceLoss
  # Kept consistent with log.metrics.dice_score.from_logits above.
  from_logits: True
  batch_reduction: 'mean' 

In [None]:
%%yaml model_cfg  

model:
  _class: ese.experiment.models.UNet
  # NOTE(review): the debug option set sets data.num_slices: 2 while this stays
  # 1; if slices are stacked as input channels these must agree — confirm.
  in_channels: 1
  out_channels: 1
  # Five filter widths (presumably one per UNet level); the main sweep
  # overrides this via 'model.filters'.
  filters: [64, 64, 64, 64, 64]
  convs_per_block: 3

In [None]:
%%yaml dataset_cfg 

data:
  _class: ese.experiment.datasets.WMH
  annotator: observer_o12
  axis: 0
  dataset: WMH  
  preload: False 
  # The debug option set overrides this to 2 via 'data.num_slices'.
  num_slices: 1
  task: Amsterdam 
  # YAML parses 0.2 as a float, not a string; quote it if the dataset class
  # expects a version string.
  version: 0.2

In [None]:
%%yaml lite_aug_cfg

# "Light" augmentation list, attached to every config under 'augmentations'
# when the configs are assembled. Each entry maps a transform name to its
# keyword arguments (p is presumably the probability of applying that
# transform — confirm).
- RandomAffine:
    p: 0.5
    degrees: [0, 360]
    translate: [0, 0.2]
    scale: [0.8, 1.1]
- RandomVariableElasticTransform:
    p: 0.75
    alpha: [1, 2] 
    sigma: [7, 9]
- RandomHorizontalFlip:
    p: 0.5
- RandomVerticalFlip:
    p: 0.5

In [None]:
%%yaml callbacks_cfg

callbacks:
  # Presumably invoked every training step, versus once per epoch for the
  # 'epoch' list below — confirm.
  step:
    - ese.experiment.callbacks.ShowPredictions
  epoch:
    - ese.experiment.callbacks.WandbLogger
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    # Monitors the validation-phase dice_score (the metric defined under
    # log.metrics in default_cfg).
    - ionpy.callbacks.ModelCheckpoint:
        monitor: dice_score
        phase: val

In [None]:
# Need to define the experiment name
# Name shared by every job in this sweep (also embedded in each wandb label).
exp_name = 'BigUNetsRule'

# Ablation grid. Each dict maps dotted config keys to lists of candidate
# values; the assembly cell expands each dict with dict_product (presumably
# the Cartesian product) into one config per combination.
debug_options = {
    # Quick sanity-check run, logged under its own debug directory.
    'log.root': [f'{log_root_dir}/debug'],
    'dataloader.batch_size': [4],
    'data.num_slices': [2],
}
sweep_options = {
    'log.root': [f'{log_root_dir}/{exp_name}'],
    'dataloader.batch_size': [4],
    # Wider UNets: five levels of 128, 256 or 512 filters each.
    'model.filters': [[width] * 5 for width in (128, 256, 512)],
    'optim.weight_decay': [0, 0.00001, 0.0001],
    'dataloader.num_workers': [4],
}
option_set = [debug_options, sweep_options]

In [None]:
from ionpy.util import dict_product, Config
import copy

def proc_exp_name(exp_name, cfg):
    """Build a compact wandb run label from one job's config overrides.

    Args:
        exp_name: Experiment name (a str). It is always the leading component.
        cfg: Mapping of dotted config keys to this job's values. The
            ``log.root`` entry is skipped.

    Returns:
        ``{"log.wandb_string": label}``. ``label`` joins these parts with ``-``:
            - ``"exp_name:<exp_name>"``;
            - one ``"<leaf>:<value>"`` per remaining override. ``<leaf>`` is
              the key text after its last ``.``, and ``<value>`` is
              ``str(value)`` with all spaces removed.
    """
    parts = ["exp_name:" + exp_name]
    parts.extend(
        "{}:{}".format(key.rpartition(".")[2], str(value).replace(" ", ""))
        for key, value in cfg.items()
        if key != "log.root"
    )
    return {"log.wandb_string": "-".join(parts)}

# Assemble the base config shared by every job.
light_augmentations = list(copy.deepcopy(lite_aug_cfg))
base_cfg = (
    Config(default_cfg)
    .update(model_cfg)
    .update(dataset_cfg)
    .update(callbacks_cfg)
)

# Expand every option dict into concrete configs. For each one: apply the
# overrides, attach the wandb run label and augmentations, then sanity-check.
cfgs = []
for option_dict in option_set:
    for cfg_update in dict_product(option_dict):
        cfg = base_cfg.update(cfg_update)
        cfg = cfg.update(proc_exp_name(exp_name, cfg_update))
        # Each config gets its own copy, so no two jobs share (and can mutate)
        # the same augmentation list object.
        cfg = cfg.set('augmentations', copy.deepcopy(light_augmentations))
        cfg = validate_cfg(cfg)
        cfgs.append(cfg)

# Finally, set the experiment name so jobs can be submitted.
srunner.set_exp_name(exp_name)

## Run the jobs

## Debug Station

In [12]:
# Debug run: cfgs[0] is the first combination of the debug option dict, which
# logs under f'{log_root_dir}/debug'. NOTE(review): run_exp presumably trains
# inside this kernel (the saved output is a KeyboardInterrupt) — confirm.
srunner.run_exp(cfgs[0])

KeyboardInterrupt: 

## Submit the Configs as Long-Running Jobs

In [None]:
# Submit every assembled config as a long-running job.
# NOTE(review): `cfgs` also contains the config(s) built from the debug option
# dict (log.root .../debug). Filter them out if only the main sweep should run.
srunner.submit_exps(cfgs)

In [None]:
# Number of jobs the runner currently tracks (presumably the ones submitted above).
len(srunner.jobs)

In [None]:
# Inspect one job's output text (presumably its captured stdout). The index 3
# is hard-coded; change it to look at a different job.
print(srunner.jobs[3].stdout())