# training_utils

> Utility functions for training. These include getters for loss functions and optimizers, and a function to compute evaluation metrics.

In [None]:
#| default_exp training_utils

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.utils import *

In [None]:
#| export
from operator import gt, lt

import torch

In [None]:
#| export
losses_dict = {}  # Registry mapping a loss name to its loss function

def get_loss_func(loss:str # Key into the losses dictionary
                    ):
    "Return the loss function registered under `loss` in `losses_dict`."
    # Explicit check rather than `assert`: assertions are stripped under `python -O`,
    # which would silently drop this validation. AssertionError is still raised so
    # callers that already catch it keep working.
    if loss not in losses_dict:
        raise AssertionError(f'{loss} is not an existing loss function, choose one from {losses_dict.keys()}.')

    return losses_dict[loss]

In [None]:
#|echo: false
# Show the registered loss names (iterating a dict yields its keys directly).
print("The existing keys are:\n" + "\n".join(losses_dict))

# No loss functions are registered by default.
test_eq(losses_dict.keys(), [])

The existing keys are:



In [None]:
#| export
optimizers_dict = {}  # Registry mapping an optimizer name to the callable that builds it

def get_optimizer(optim:str, # Key into the optimizer dictionary
                  kwargs:dict # Optimizer parameters
                    ):
    "Build the optimizer registered under `optim`, forwarding every entry of `kwargs` as a keyword argument."
    # Explicit check rather than `assert`: assertions are stripped under `python -O`,
    # which would silently drop this validation. AssertionError is still raised so
    # callers that already catch it keep working.
    if optim not in optimizers_dict:
        raise AssertionError(f'{optim} is not an existing optimizer, choose one from {optimizers_dict.keys()}.')

    return optimizers_dict[optim](**kwargs)

In [None]:
#|echo: false
# Show the registered optimizer names (iterating a dict yields its keys directly).
print("The existing keys are:\n" + "\n".join(optimizers_dict))

# No optimizers are registered by default.
test_eq(optimizers_dict.keys(), [])

The existing keys are:



In [None]:
#|export
# For each metric name, the comparison operator that decides whether a new value
# beats a reference value.
metrics_dict = dict(
    loss=lt,   # a smaller loss is better
    step=gt,   # a larger step count is better
)

def compute_metrics(name:str,               # Name of the training stage (train, val, test)
                    outputs:torch.Tensor,   # The output of the model (not read yet)
                    labels:torch.Tensor,    # The ground truth (not read yet)
                    loss:float,             # The loss of the model
                    example_ct:int,         # Number of examples processed by the model
                    step_ct:int,            # Number of backpropagation steps the model has done
                    epoch:float             # The training epoch
                    )->dict:                # Dictionary of the metrics
    "Collect the loss and training counters into a dict keyed by `<name>/<metric>`."
    # `outputs` and `labels` are accepted but currently unused; presumably reserved
    # for metrics derived from predictions — TODO confirm.
    stage_metrics = {}
    stage_metrics[f'{name}/loss'] = loss
    stage_metrics[f'{name}/example_ct'] = example_ct
    stage_metrics[f'{name}/step_ct'] = step_ct
    stage_metrics[f'{name}/epoch'] = epoch
    return stage_metrics

In [None]:
#|echo: false
# Show the metric names that have a comparison operator (iterating a dict yields its keys).
print("The existing keys are:\n" + "\n".join(metrics_dict))

# Only `loss` and `step` are defined by default.
test_eq(metrics_dict.keys(), ['loss', 'step'])

The existing keys are:
loss
step


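The functions in the metrics dictionary specify which comparison operator decides whether one value of a metric is better than another. For instance, a loss is considered better when it is smaller (`lt`), while a step count is considered better when it is larger (`gt`).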

In [None]:
# (metric, first value, second value, expected result of the comparison)
comparison_cases = [
    ('loss', 0.1, 0.2, True),
    ('loss', 0.05, 0.001, False),
    ('step', 500, 420, True),
]
for metric, first, second, expected in comparison_cases:
    test_eq(metrics_dict[metric](first, second), expected)

In [None]:
#| hide
# Export the `#| export` cells of this notebook to the library module.
import nbdev
nbdev.nbdev_export()