In [None]:
# default_exp utils

In [None]:
#hide
%load_ext autoreload
%autoreload 2

In [None]:
#hide
from nbdev.showdoc import *

# VAT utils

> Helper functions and classes

In [None]:
#export
from fastai.basics import *
from fastai.test_utils import synth_learner
from fastai.callback.all import *

## Recorder customization

Patches `Recorder` to enable adding train-only metrics. Can be used to log adversarial loss.

In [None]:
#export
@patch
def before_fit(self:Recorder):
    "Prepare state for training, reserving one column per train-only metric"
    self.lrs,self.iters,self.losses,self.values = [],[],[],[]
    names = self.metrics.attrgot('name')
    train_only_names = self.train_only_metrics.attrgot('name')
    if self.train_metrics and self.valid_metrics:
        # Column order has to mirror `_train_mets` (loss, train-only metrics, regular metrics),
        # otherwise the headers and the logged values end up misaligned.
        train_names = L('train_loss') + train_only_names + names.map('train_{}')
        valid_names = (L('loss') + names).map('valid_{}')
        names = train_names + valid_names
    elif self.valid_metrics: names = L('train_loss', *train_only_names, 'valid_loss') + names
    else: names = L('train_loss') + train_only_names + names
    if self.add_time: names.append('time')
    self.metric_names = 'epoch'+names
    self.smooth_loss.reset()


In [None]:
#export
@patch(as_prop=True)
def _train_mets(self:Recorder):
    "Training-phase metrics: smoothed loss, train-only metrics, then regular metrics if `train_metrics`"
    # Nothing to report when the training phase was cancelled
    if getattr(self, 'cancel_train', False): return L()
    regular = self.metrics if self.train_metrics else L()
    return L(self.smooth_loss) + self.train_only_metrics + regular

@patch(as_prop=True)
def train_only_metrics(self:Recorder):
    "Metrics stored in `_train_only_metrics`, or an empty `L` when that attribute is not set"
    try: return self._train_only_metrics
    except AttributeError: return L()

In [None]:
#export
@patch
def add_train_metrics(self:Recorder, *metrics):
    "Append `metrics` to the train-only metrics kept in `_train_only_metrics`"
    registered = self.train_only_metrics
    registered += L(*metrics)
    self._train_only_metrics = registered

In [None]:
def _met_func():
    "Dummy metric function that always reports 0"
    return 0
metric = ValueMetric(_met_func, 'nothing')
learn = synth_learner(metrics=rmse)
learn.recorder.add_train_metrics(metric)
assert learn.recorder._train_only_metrics[0] is metric
learn.fit(2)
# The train-only metric must get its own column right after `train_loss` ...
assert list(learn.recorder.metric_names[:3]) == ['epoch', 'train_loss', 'nothing']
# ... and its value must be recorded in that position (`values` rows do not include the epoch)
assert learn.recorder.values[-1][1] == 0

epoch,train_loss,nothing,valid_loss,_rmse,time
0,15.411602,0,11.542326,3.3974,00:00
1,13.509046,0,8.202074,2.863926,00:00


## WandbCallback customization

1. Record the best value per metric as well as the current one - used for post-analysis
2. Log additional values found in `learn.log_extras` - used to log losses computed by callbacks

In [None]:
#export
try:
    import wandb
    from fastai.callback.wandb import *
    from fastai.callback.wandb import _make_plt, _format_config_value, _format_config, _format_metadata
except ImportError:
    pass

class WandbCallback(Callback):
    """Saves model topology, losses & metrics

    On top of logging the current metrics at every epoch, this callback keeps the best
    value seen so far for each metric (logged as `best_<name>`) and merges the content of
    `learn.log_extras` (when that attribute exists) into the per-batch log.

    Parameters:
        log: forwarded to `wandb.watch(..., log=log)`.
        log_preds: whether to log prediction samples at the end of every epoch.
        log_model: whether to upload the last saved model after fit (needs a `save_model` callback).
        log_dataset: `True` to log `dls.path`, or an explicit path (str/Path) to a directory.
        dataset_name: forwarded to `log_dataset(name=...)`.
        valid_dl: DataLoader used for prediction samples; sampled from `dls.valid_ds` when not given.
        n_preds: number of validation items sampled for prediction logging.
        seed: seed of the `random.Random` used to pick those items.
        reorder: forwarded to `FetchPredsCallback`.
        compare: callable `is_better(new, best)` or list of them, matched positionally with the
            logged metrics; when `None`, `np.less` is used for names containing 'loss' or
            'error' and `np.greater` for every other metric.

    Raises:
        ValueError: if `wandb.init()` has not been called before instantiation.
    """
    remove_on_fetch,order = True,Recorder.order+1
    # Record if watch has been called previously (even in another instance)
    _wandb_watch_called = False
    # Recorder columns that are never part of the per-epoch metrics dict
    _skipped_names = ('train_loss', 'epoch', 'time')

    def __init__(self, log="gradients", log_preds=True, log_model=True, log_dataset=False, dataset_name=None, valid_dl=None, n_preds=36, seed=12345, reorder=True, compare=None):
        # Check if wandb.init has been called
        if wandb.run is None:
            raise ValueError('You must call wandb.init() before WandbCallback()')
        # W&B log step
        self._wandb_step = wandb.run.step - 1  # -1 except if the run has previously logged data (incremented at each batch)
        self._wandb_epoch = 0 if not(wandb.run.step) else math.ceil(wandb.run.summary['epoch']) # continue to next epoch
        store_attr('log,log_preds,log_model,log_dataset,dataset_name,valid_dl,n_preds,seed,reorder,compare')

    def before_fit(self):
        "Call watch method to log model topology, gradients & weights"
        self.run = not hasattr(self.learn, 'lr_finder') and not hasattr(self, "gather_preds") and rank_distrib()==0
        if not self.run: return

        # Log config parameters
        log_config = self.learn.gather_args()
        _format_config(log_config)
        try:
            wandb.config.update(log_config, allow_val_change=True)
        except Exception as e:
            print(f'WandbCallback could not log config parameters -> {e}')

        if not WandbCallback._wandb_watch_called:
            WandbCallback._wandb_watch_called = True
            # Logs model topology and optionally gradients and weights
            wandb.watch(self.learn.model, log=self.log)

        # log dataset
        assert isinstance(self.log_dataset, (str, Path, bool)), 'log_dataset must be a path or a boolean'
        if self.log_dataset is True:
            if Path(self.dls.path) == Path('.'):
                print('WandbCallback could not retrieve the dataset path, please provide it explicitly to "log_dataset"')
                self.log_dataset = False
            else:
                self.log_dataset = self.dls.path
        if self.log_dataset:
            self.log_dataset = Path(self.log_dataset)
            assert self.log_dataset.is_dir(), f'log_dataset must be a valid directory: {self.log_dataset}'
            metadata = {'path relative to learner': os.path.relpath(self.log_dataset, self.learn.path)}
            log_dataset(path=self.log_dataset, name=self.dataset_name, metadata=metadata)

        # log model
        if self.log_model and not hasattr(self, 'save_model'):
            print('WandbCallback requires use of "SaveModelCallback" to log best model')
            self.log_model = False

        if self.log_preds:
            try:
                if not self.valid_dl:
                    # Imported here so the `isinstance` check below does not depend on a star-import
                    # happening to provide the name; an import failure is handled by the `except` below.
                    from fastai.tabular.all import TabularDataLoaders
                    #Initializes the batch watched
                    wandbRandom = random.Random(self.seed)  # For repeatability
                    self.n_preds = min(self.n_preds, len(self.dls.valid_ds))
                    idxs = wandbRandom.sample(range(len(self.dls.valid_ds)), self.n_preds)
                    if isinstance(self.dls,  TabularDataLoaders):
                        test_items = getattr(self.dls.valid_ds.items, 'iloc', self.dls.valid_ds.items)[idxs]
                        self.valid_dl = self.dls.test_dl(test_items, with_labels=True, process=False)
                    else:
                        test_items = [getattr(self.dls.valid_ds.items, 'iloc', self.dls.valid_ds.items)[i] for i in idxs]
                        self.valid_dl = self.dls.test_dl(test_items, with_labels=True)
                self.learn.add_cb(FetchPredsCallback(dl=self.valid_dl, with_input=True, with_decoded=True, reorder=self.reorder))
            except Exception as e:
                self.log_preds = False
                print(f'WandbCallback was not able to prepare a DataLoader for logging prediction samples -> {e}')

        self.best_metrics = {}
        if self.compare is None:
            # Build one comparator per metric that `after_epoch` logs. Filtering by name (instead of
            # slicing off the first two and the last column) keeps the last metric tracked when the
            # recorder has no trailing 'time' column.
            tracked = [n for n in self.recorder.metric_names if n not in self._skipped_names]
            self.compare = [np.less if (('loss' in n) or ('error' in n)) else np.greater for n in tracked]
        elif not islisty(self.compare): self.compare = [self.compare]


    def after_batch(self):
        "Log hyper-parameters and training loss"
        if self.training:
            self._wandb_step += 1
            # Fractional epoch: advances by one full epoch every `n_iter` batches
            self._wandb_epoch += 1/self.n_iter
            hypers = {f'{k}_{i}':v for i,h in enumerate(self.opt.hypers) for k,v in h.items()}
            # Extra values published on the learner (if any) are logged alongside the losses
            log_extras = getattr(self.learn, 'log_extras', {})
            wandb.log({'epoch': self._wandb_epoch, 'train_loss': to_detach(self.smooth_loss.clone()), 'raw_loss': to_detach(self.loss.clone()), **hypers, **log_extras}, step=self._wandb_step)

    def log_predictions(self, preds):
        "Log the prediction samples in `preds`, an `(inp, preds, targs, out)` tuple"
        inp,preds,targs,out = preds
        b = tuplify(inp) + tuplify(targs)
        x,y,its,outs = self.valid_dl.show_results(b, out, show=False, max_n=self.n_preds)
        wandb.log(wandb_process(x, y, its, outs), step=self._wandb_step)

    def _epoch_metrics(self):
        "Current recorder log as a `{name: value}` dict, without the `_skipped_names` columns"
        return {n:s for n,s in zip(self.recorder.metric_names, self.recorder.log) if n not in self._skipped_names}

    def after_epoch(self):
        "Log validation loss and custom metrics & log prediction samples"
        # Correct any epoch rounding error and overwrite value
        self._wandb_epoch = round(self._wandb_epoch)
        wandb.log({'epoch': self._wandb_epoch}, step=self._wandb_step)
        # Log sample predictions
        if self.log_preds:
            try:
                self.log_predictions(self.learn.fetch_preds.preds)
            except Exception as e:
                self.log_preds = False
                print(f'WandbCallback was not able to get prediction samples -> {e}')
        metrics = self._epoch_metrics()
        wandb.log(metrics, step=self._wandb_step)
        # Refresh and log the best value seen so far for each metric
        self.update_best(metrics)
        wandb.log(self.best_metrics, step=self._wandb_step)


    def after_fit(self):
        "Upload the last saved model (if requested), remove the prediction callback and sync the last step"
        if self.log_model:
            if self.save_model.last_saved_path is None:
                print('WandbCallback could not retrieve a model to upload')
            else:
                metadata = self._epoch_metrics()
                log_model(self.save_model.last_saved_path, metadata=metadata)
        self.run = True
        if self.log_preds: self.remove_cb(FetchPredsCallback)
        wandb.log({})  # ensure sync of last step
        self._wandb_step += 1

    def update_best(self, metrics):
        "Update `best_metrics` with every value in `metrics` that beats the stored `best_<name>`"
        # Comparators are matched to metrics by position; `zip` stops at the shorter of the two
        for is_better, (n,v) in zip(self.compare, metrics.items()):
            current_best = self.best_metrics.get(f'best_{n}', None)
            if current_best is None or is_better(v, current_best):
                self.best_metrics[f'best_{n}'] = v

## Fin

In [None]:
#hide
# Export every notebook flagged for export into its Python module
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted index.ipynb.
