In [None]:
# default_exp core

## Core

> Loss functions, profit metrics and convenience helpers exported to the `core` module.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
#hide 
import torch.nn.functional as F
import torch as torch
from functools import partial

In [None]:
#export
def test():
    """Smoke-test helper: prints the word 'test' and returns None."""
    greeting = 'test'
    print(greeting)

In [None]:
# call the exported helper; it prints "test"
test()

test


### Loss Functions

In [None]:
#export
def leaky_loss(preds, y_true, alpha=0.05):
    '''
    Objective function: negative mean of `leaky_relu(preds, alpha) * y_true`.

    Positive predictions contribute `preds * y_true`; negative predictions are
    not ignored but scaled down by the factor `alpha`. The sign is flipped so
    that minimising the loss maximises the (leaky-)weighted return.

    Parameters
    ----------
    preds : torch.Tensor
        Model outputs, e.g. shape (B,) or (B, 1).
    y_true : torch.Tensor
        Targets, e.g. shape (B,) or (B, 1); cast to float.
    alpha : float
        Negative slope passed to `F.leaky_relu`.

    Returns
    -------
    torch.Tensor
        0-d tensor with `requires_grad=True`.
    '''
    activated = F.leaky_relu(preds, alpha).squeeze()
    targets = y_true.float()
    # When both hold the same number of elements, compare them element-wise.
    # Otherwise (B,) * (B, 1) would broadcast to a (B, B) matrix and silently
    # average every prediction against every target.
    if targets.numel() == activated.numel():
        targets = targets.reshape(activated.shape)
    loss = (activated * targets).mean() * (-1)
    # NOTE(review): forcing requires_grad keeps backward() from raising when
    # `preds` does not require grad, but then no gradient reaches any model
    # parameter; kept for backward compatibility.
    loss.requires_grad_(True)
    return loss

In [None]:
preds = torch.tensor([-0.5, 0.7, 0.2, -1.5])
y_true = torch.tensor([100., 100., 100., 100.])
# negative predictions are scaled by the default alpha=0.05; the mean over 4 samples is negated
expected_leaky = (-0.5*100*0.05 + 0.7*100 + 0.2*100 + -1.5*100*0.05)/(4*-1)
# float32 result vs. Python float: compare with a tolerance, not exact equality
assert torch.isclose(leaky_loss(preds, y_true), torch.tensor(expected_leaky))


In [None]:
# scratch check: magnitude of the expected leaky_loss value asserted above (same numerator, without the -1 sign)
(-0.5*100*0.05 + 0.7*100 + 0.2*100 + -1.5*100*0.05)/4

20.0

### Metrics

In [None]:
#export
def unweighted_profit(preds, y_true, threshold=0):
    '''
    Metric: batch mean of `y_true`, counting only samples whose prediction is
    strictly greater than `threshold`.

    Every selected sample has weight 1. Non-selected samples contribute 0 but
    still count in the denominator of the mean.

    Parameters
    ----------
    preds : torch.Tensor
        Model outputs, e.g. shape (B,) or (B, 1).
    y_true : torch.Tensor
        Targets, e.g. shape (B,) or (B, 1); cast to float.
    threshold : float
        Predictions must exceed this value to be counted.

    Returns
    -------
    torch.Tensor
        0-d tensor.
    '''
    squeezed = preds.squeeze()
    targets = y_true.float()
    # align element-wise when sizes match; (B,) * (B, 1) would otherwise
    # broadcast to (B, B) and average every mask entry against every target
    if targets.numel() == squeezed.numel():
        targets = targets.reshape(squeezed.shape)
    selected = (squeezed > threshold).float()
    return (selected * targets).mean()

In [None]:
# only predictions above the default threshold 0 are counted, each with weight 1
assert float(unweighted_profit(preds, y_true)) == (0*100 + 1*100 + 1*100 + 0*100) / 4

In [None]:
#export
def weighted_profit(preds, y_true, threshold=0):
    '''
    Metric: batch mean of `preds * y_true`, counting only samples whose
    prediction is strictly greater than `threshold`.

    Selected samples are weighted by their own prediction. Non-selected
    samples contribute 0 but still count in the denominator of the mean.

    Parameters
    ----------
    preds : torch.Tensor
        Model outputs, e.g. shape (B,) or (B, 1).
    y_true : torch.Tensor
        Targets, e.g. shape (B,) or (B, 1); cast to float.
    threshold : float
        Predictions must exceed this value to be counted.

    Returns
    -------
    torch.Tensor
        0-d tensor.
    '''
    squeezed = preds.squeeze()  # computed once, used for both mask and weight
    targets = y_true.float()
    # align element-wise when sizes match; (B,) * (B, 1) would otherwise
    # broadcast to (B, B) and average every prediction against every target
    if targets.numel() == squeezed.numel():
        targets = targets.reshape(squeezed.shape)
    selected = (squeezed > threshold).float()
    return (selected * squeezed * targets).mean()

In [None]:
# negative predictions are masked out; positives are weighted by their own value.
# float32 result vs. Python float: compare with a tolerance, not exact equality
assert torch.isclose(weighted_profit(preds, y_true), torch.tensor((-0.5*100*0 + 0.7*100 + 0.2*100 + -1.5*100*0)/(4)))

### Convenience

In [None]:
#export
def get_loss_fn(loss_fn_name, **kwargs):
    '''
    Build a loss function from its string name, with its keyword arguments bound.

    The loss is wrapped in a `functools.partial`, and `__name__` is set on the
    partial (plain partial objects have no `__name__`). The returned callable
    therefore carries a readable name.

    Supported names: 'leaky_loss' (requires a non-None `alpha` keyword).

    Returns the named partial, or None if `loss_fn_name` is not recognised.
    Raises ValueError if 'leaky_loss' is requested without `alpha`.
    '''
    if loss_fn_name != 'leaky_loss':
        # unrecognised names fall through to None, as before
        return None
    alpha = kwargs.get('alpha')
    # explicit check instead of `assert`, which is stripped under `python -O`
    if alpha is None:
        raise ValueError('need to specify alpha with leaky_loss')
    named_loss = partial(leaky_loss, alpha=alpha)
    named_loss.__name__ = loss_fn_name
    return named_loss

In [None]:
# the named partial must behave exactly like calling leaky_loss directly
leaky_05 = get_loss_fn('leaky_loss', alpha=0.5)
assert leaky_05.__name__ == 'leaky_loss'
assert leaky_05(preds, y_true) == leaky_loss(preds, y_true, alpha=0.5)