In [None]:
# Random Imports
import os
import sys
import yaml
from pathlib import Path
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')

# Regular schema dictates that we put DATAPATH
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))
os.environ['WANDB_NOTEBOOK_NAME'] = 'train.ipynb'

from ionpy.util import Config
from ese.scripts.utils import get_option_product

from datetime import datetime
# Get today's date
today_date = datetime.now()
# Format the date as MM_DD_YY
formatted_date = today_date.strftime("%m_%d_%y")

# Setup direcrtories
root = Path("/storage/vbutoi/scratch/ESE")
scratch_root = Path("/storage/vbutoi/scratch/ESE")
code_root = Path("/storage/vbutoi/projects/ESE")

%load_ext yamlmagic
%load_ext autoreload
%autoreload 2

In [None]:
%%yaml default_cfg 

# Base training config shared by every run. Values written as '?' are
# placeholders that must be filled in before training. experiment.seed and
# loss_func._class are set per run by option_set.
# NOTE(review): the remaining '?' fields are presumably supplied by the
# dataset config merged into base_cfg; confirm.
experiment:
  seed: '?' 
    
dataloader:
  batch_size: '?' 
  num_workers: '?' 
  pin_memory: True 

optim: 
  _class: torch.optim.Adam
  # NOTE(review): keep the decimal point. PyYAML's YAML 1.1 resolver reads
  # `1e-4` (no dot) as a string; confirm if yamlmagic uses another loader.
  lr: 1.0e-4
  weight_decay: 0.0 
  
train:
  epochs: '?' 
  eval_freq: '?' 

# Names of the train/val splits to use (added for the additional-data experiment).
data:
  train_splits: train
  val_splits: val
  
loss_func: 
  _class: '?'
  from_logits: True
  batch_reduction: 'mean' 

# Disabled alternative: a weighted combination of SoftDiceLoss and PixelCELoss.
# NOTE(review): presumably consumed by a compound-loss wrapper; confirm the
# expected schema before enabling.
# loss_func: 
#   classes:
#   - _class: ese.experiment.losses.SoftDiceLoss
#     from_logits: True
#     batch_reduction: 'mean' 
#   - _class: ese.experiment.losses.PixelCELoss
#     from_logits: True
#     batch_reduction: 'mean' 
#   weights:
#   - 1.0
#   - 1.0

In [None]:
%%yaml model_cfg  

# Network architecture config.
# NOTE(review): convs_per_block is a '?' placeholder, presumably supplied by
# the dataset config; confirm.
model:
  _class: ese.experiment.models.unet.UNet
  # NOTE(review): presumably one channel width per UNet level; confirm against the UNet signature.
  filters: [64, 64, 64, 64, 64]
  convs_per_block: '?' 

In [None]:
%%yaml callbacks_cfg

# Callbacks grouped by trigger.
# NOTE(review): 'step' entries presumably fire every training step and
# 'epoch' entries every epoch; confirm with the experiment runner.
callbacks:
  step:
    - ese.experiment.callbacks.ShowPredictions
  epoch:
    - ese.experiment.callbacks.WandbLogger
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    # An entry written as a mapping carries arguments for that callback
    # (presumably constructor kwargs); here checkpointing monitors val dice_score.
    - ionpy.callbacks.ModelCheckpoint:
        monitor: dice_score
        phase: val

In [None]:
%%yaml aug_cfg
# Optional config that defines which augmentations to apply, keyed by split.
# Currently only SVLS is applied, and only to the training split.

augmentations:
  train:
    - ese.experiment.augmentation.SVLS:
        ksize: 3
        sigma: 1
        always_apply: True
        include_center: True

## Debug Station

In [None]:
# Dataset to train on (also names the dataset config file that gets loaded).
dset = "OCTA_6M"

# Launch parameters: the value passed as data.label, and the loss class name.
lab = 100
# lab = 255
loss_func = 'SoftDiceLoss'
# loss_func = 'PixelCELoss'

# Experiment name. The date prefix is fixed; it is not derived from today's date.
# exp_name = f'06_27_24_{dset}_{loss_func}_{lab}'
exp_name = f'06_27_24_{dset}_{loss_func}_{lab}_wSVLS'

# Run one job per seed: start_seed, start_seed + 1, ..., start_seed + num_seeds - 1.
start_seed = 40
num_seeds = 8

exp_root = str(root.joinpath('training', exp_name))

# One option dict per seed. Every option holds a single value, so each dict
# presumably expands to exactly one run config.
seeds = range(start_seed, start_seed + num_seeds)
option_set = [
    {
        'log.root': [exp_root],
        'data.preload': [False],
        'data.label_threshold': [0.5],
        'data.label': [lab],
        'experiment.seed': [seed],
        'loss_func._class': [f'ese.experiment.losses.{loss_func}'],
    }
    for seed in seeds
]

In [None]:
# Load the per-dataset training config shipped with the repo:
# <code_root>/ese/experiment/configs/training/<dset>.yaml
cal_cfg_root = code_root.joinpath("ese", "experiment", "configs", "training")
dataset_cfg = yaml.safe_load((cal_cfg_root / f"{dset}.yaml").read_text())

In [None]:
# Layer the model, dataset, callback and augmentation configs on top of
# default_cfg via Config.update. To train without SVLS augmentation, drop
# aug_cfg from this list.
cfg_updates = [model_cfg, dataset_cfg, callbacks_cfg, aug_cfg]
base_cfg = Config(default_cfg).update(cfg_updates)

# Expand the option set into the concrete per-run configs.
cfgs = get_option_product(exp_name, option_set, base_cfg)

In [None]:
# Number of run configs generated (presumably one per seed; TODO confirm against get_option_product).
len(cfgs)

## Running Jobs

In [None]:
######## FOR DEBUGGING
# Run the first config in this kernel, with examples shown and W&B tracking off.
from ese.experiment.experiment import run_ese_exp, CalibrationExperiment

debug_run_kwargs = dict(
    config=cfgs[0],
    experiment_class=CalibrationExperiment,
    gpu='0',  # alternatively: gpu='4'
    run_name='debug',
    show_examples=True,
    track_wandb=False,
)
run_ese_exp(**debug_run_kwargs)

In [None]:
# Batch submission (disabled): uncomment to submit every config in `cfgs`
# across the listed GPUs with W&B tracking enabled, instead of the single
# debug run above.
# # FOR SUBMISSION
# from ese.experiment.experiment import submit_ese_exps, CalibrationExperiment 

# submit_ese_exps(
#     config_list=cfgs,
#     experiment_class=CalibrationExperiment,
#     available_gpus=['0', '1', '2', '3'],
#     # available_gpus=['4', '5', '6', '7'],
#     track_wandb=True
# )