In [None]:
import sys
sys.path.append('/storage/vbutoi/projects')
sys.path.append('/storage/vbutoi/libraries')
sys.path.append('/storage/vbutoi/projects/ESE')
sys.path.append('/storage/vbutoi/projects/UniverSeg')
import torch
torch.set_printoptions(linewidth=200)
import seaborn as sns
sns.set_style("darkgrid")

import os 
os.environ['DATAPATH'] = ':'.join((
       '/storage/vbutoi/datasets',
))

# Results loader object does everything
from ionpy.analysis import ResultsLoader
from pathlib import Path
root = Path("/storage/vbutoi/scratch/ESE")
# Setup direcrtories
calibration_root = '/storage/vbutoi/scratch/ESE/calibration'
rs = ResultsLoader()

# For using code without restarting.
%load_ext autoreload
%autoreload 
# For using yaml configs.
%load_ext yamlmagic

In [None]:
%%yaml calibration_config

# Base config for post-hoc calibration runs. Every value written as '?' is a
# placeholder; the option sets built later in this notebook supply it per run
# (get_calibration_options and base_options).
data:
  # data.preload <- the `preload` argument of get_calibration_options.
  preload: '?' 
  # data.iters_per_epoch <- base_options.
  iters_per_epoch: '?'

# batch_size / num_workers <- base_options.
dataloader:
  batch_size: '?' 
  num_workers: '?' 
  pin_memory: True 

optim: 
  _class: torch.optim.Adam
  lr: 1.0e-4
  weight_decay: 0.0 

train:
  epochs: 300 
  eval_freq: 10 
  # One value per run directory returned by gather_exp_paths(ensemble_root).
  pretrained_dir: '?'
  # presumably the metric used to pick which checkpoint of each pretrained run
  # to load — TODO confirm.
  pretrained_select_metric: 'val-dice_score'

log:
  checkpoint_freq: 50 
  # log.root <- <calibration_root>/<exp_name>/Individual_<calibrator>.
  root: '?'
  # Calibration metrics, all computed from logits with 15 confidence bins.
  metrics:
    ece_loss:
      _fn: ese.experiment.metrics.image_ece_loss
      from_logits: True
      num_bins: 15
      # ignore_index: 0  (optional; useful for binary tasks)
    edge_ece_loss:
      _fn: ese.experiment.metrics.image_edge_ece_loss
      from_logits: True
      num_bins: 15
      neighborhood_width: 3
      # ignore_index: 0  (optional; useful for binary tasks)
    elm_loss:
      _fn: ese.experiment.metrics.image_elm_loss
      from_logits: True
      num_bins: 15
      neighborhood_width: 3
      # ignore_index: 0  (optional; useful for binary tasks)

# Pixel-wise cross-entropy on logits, mean-reduced over the batch.
loss_func: 
  _class: ionpy.loss.PixelCELoss 
  from_logits: True 
  batch_reduction: 'mean' 
  # Optional settings for binary tasks (uncomment to use):
  # ignore_index: 0
  # weights:
  #     - 0.5
  #     - 0.5

In [None]:
%%yaml callbacks_cfg

# Callbacks attached to each calibration run.
callbacks:
  # Hooks registered under `step` (presumably invoked every training step).
  step:
    - ese.experiment.callbacks.ShowPredictions
  # Hooks registered under `epoch` (presumably invoked at the end of each epoch).
  epoch:
    - ese.experiment.callbacks.WandbLogger
    - ionpy.callbacks.ETA
    - ionpy.callbacks.JobProgress
    - ionpy.callbacks.TerminateOnNaN
    - ionpy.callbacks.PrintLogged
    # Checkpoints on the validation-phase `ece_loss` metric defined under
    # log.metrics; presumably keeps the best (lowest) value — TODO confirm.
    - ionpy.callbacks.ModelCheckpoint:
        monitor: ece_loss 
        phase: val

In [None]:
%%yaml aug_cfg 

# Augmentations used for CityScapes.
augmentations:
    train:
        - Resize:
            height: 256
            width: 512
        - HorizontalFlip:
            p: 0.5
    val:
        - Resize: # Resizing val data is unusual, but keeps it at the same 256x512 as training so results are comparable.
            height: 256
            width: 512

In [None]:
%%yaml calibrator_model_cfg  

model:
  # Calibrator class path; filled per run by get_calibration_options ('model._class').
  _class: '?'
  # presumably the spatial neighborhood size read by neighborhood-aware
  # calibrators — TODO confirm which calibrator classes use it.
  neighborhood_width: 3

In [None]:
from typing import List, Optional

def get_calibration_options(
    exp_name: str, 
    calibrator: str, 
    paths_to_calibrate: List[str], 
    preload: bool = False,
    base_options: Optional[dict] = None
):
    """Build the option set for one post-hoc calibration sweep.

    Args:
        exp_name: Experiment name; becomes a directory under ``calibration_root``.
        calibrator: Short calibrator alias (e.g. ``"TempScaling"``). Any string
            that is not a known alias (e.g. ``"Vanilla"`` or a full class path)
            is used verbatim as ``model._class``.
        paths_to_calibrate: Pretrained run directories; used as the list of
            values for ``train.pretrained_dir``.
        preload: Single value for ``data.preload``.
        base_options: Extra ``{dotted.key: values}`` entries merged in last, so
            they override any computed key with the same name.

    Returns:
        ``([options], log_root)``: ``options`` maps dotted config keys to lists
        of candidate values; ``log_root`` is the directory the runs log to.
    """
    # Short alias -> fully qualified calibrator class.
    aliases = {
        "TempScaling": "ese.experiment.models.calibrators.Temperature_Scaling",
        "VectorScaling": "ese.experiment.models.calibrators.Vector_Scaling",
        "DirichletScaling": "ese.experiment.models.calibrators.Dirichlet_Scaling",
        "LTS": "ese.experiment.models.calibrators.LTS",
        "NectarScaling": "ese.experiment.models.calibrators.NECTAR_Scaling",
        "ConstrainedNS": "ese.experiment.models.calibrators.Constrained_NS"
    }
    model_class = aliases.get(calibrator, calibrator)

    # The log directory is named after the alias the caller passed, not the
    # resolved class path.
    log_root = f'{calibration_root}/{exp_name}/Individual_{calibrator}'

    # Insertion order of these keys is kept stable for downstream consumers.
    options = {}
    options['log.root'] = [log_root]
    options['data.preload'] = [preload]
    options['train.pretrained_dir'] = paths_to_calibrate
    options['model._class'] = [model_class]

    if base_options is None:
        return [options], log_root
    options.update(base_options)
    return [options], log_root

In [None]:
# Get the models which will be used in an ensemble.
from ese.scripts.utils import gather_exp_paths

### Calibrator choices for `calibrator=` below:
#   Vanilla          - no calibration technique
#   TempScaling
#   VectorScaling
#   DirichletScaling
#   LTS
#   NectarScaling
#   ConstrainedNS

# Pretrained training runs to calibrate, and the name of this calibration sweep.
ensemble_root = "/storage/vbutoi/scratch/ESE/training/01_25_24_CityScapes_Dice"
exp_name = "01_26_24_CityScapes_Ensemble"

# Options shared by every run in the sweep (fill the '?' data/dataloader fields).
base_options = {
    'data.iters_per_epoch': [None],
    'dataloader.batch_size': [4],
    'dataloader.num_workers': [2],
}

option_set, log_root = get_calibration_options(
    exp_name=exp_name,
    calibrator='Vanilla',
    paths_to_calibrate=gather_exp_paths(ensemble_root),
    preload=False,
    base_options=base_options,
)

In [None]:
from ese.scripts.utils import get_option_product
from ionpy.util import Config

# Merge the model, callback, and augmentation sub-configs into the base
# calibration config.
base_cfg = Config(calibration_config).update(
    [calibrator_model_cfg, callbacks_cfg, aug_cfg]
)
# Expand the option set over the base config (presumably one config per
# combination of option values).
cfgs = get_option_product(exp_name, option_set, base_cfg)

In [None]:
# Sanity check before launching: the debug cell below indexes cfgs[0], so fail
# early with a clear message if the option product came out empty (e.g. if
# gather_exp_paths found no runs under ensemble_root).
assert len(cfgs) > 0, "No configs were generated; check option_set / the gathered run paths."
len(cfgs)

## Running Jobs

In [None]:
from ese.experiment.experiment import run_ese_exp, submit_ese_exps, PostHocExperiment

In [None]:
# Debug run: launch only the first config as a single job on GPU '3', with
# W&B tracking and example display turned off.
run_ese_exp(
    config=cfgs[0],
    experiment_class=PostHocExperiment,
    run_name='debug',
    gpu='3',
    track_wandb=False,
    show_examples=False,
)

In [None]:
# ###### Run Batch Jobs
# submit_ese_exps(
#     exp_root=log_root,
#     experiment_class=PostHocExperiment,
#     config_list=cfgs,
#     track_wandb=True,
#     # available_gpus=['0', '1', '2', '3']
#     available_gpus=['4', '5', '6', '7']
# )