In [None]:
# default_exp losses

# Losses

> Loss functions, currently the WARP (Weighted Approximate-Rank Pairwise) ranking loss and the helper it uses.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import torch
def num_tries_gt_zero(scores, batch_size, max_trials, max_num, device):
    '''
    Find, for each row of `scores`, the 1-based position of the first entry > 0.

    scores: [batch_size x N] float scores
    batch_size, max_trials: retained for backward compatibility; the result is
        derived from the shape of `scores` itself
    max_num: value reported for rows with no entry > 0; it should exceed N so it
        never beats a real position in the row-wise min()
    device: torch.device (or device string) the result is moved to
    returns: [batch_size] LongTensor -- the lowest column index per row where
        scores were first greater than 0, plus 1; `max_num` where no entry is > 0
    '''
    num_rows, num_cols = scores.size(0), scores.size(1)
    if num_cols == 0:
        # Nothing can be > 0, and min() over an empty dimension would raise.
        tries = torch.full((num_rows,), max_num, dtype=torch.long, device=scores.device)
        return tries.to(device)

    # 1-based column positions [1, 2, ..., N], broadcast against every row below.
    positions = torch.arange(1, num_cols + 1, device=scores.device)
    # Entries that are not > 0 get max_num so they are never picked by min().
    # A dense torch.where replaces the former round trip through the deprecated
    # torch.sparse.LongTensor / torch.cuda.sparse.LongTensor constructors.
    candidates = torch.where(scores > 0, positions, torch.full_like(positions, max_num))
    tries = torch.min(candidates, dim=1)[0]
    return tries.to(device)


def warp_loss(positive_predictions, negative_predictions, num_labels, device):
    '''
    WARP (Weighted Approximate-Rank Pairwise) loss, summed over the batch.

    For each row, the sampled negatives are scanned left to right until the first
    one with 1 + neg - pos > 0. If that happens on try N, the positive's rank is
    estimated as floor((num_labels - 1) / N) and the row's loss is
    log(rank) * ((1 - pos) + neg). Rows where no sampled negative satisfies the
    condition contribute 0.

    positive_predictions: [batch_size x 1] floats between -1 to 1
    negative_predictions: [batch_size x N] floats between -1 to 1
    num_labels: int total number of labels in dataset (not just the subset you're using for the batch)
    device: pytorch.device the predictions live on
    returns: 0-dim tensor, the summed loss over the batch
    '''
    batch_size, max_trials = negative_predictions.size(0), negative_predictions.size(1)
    width = max_trials + 1  # N sampled negatives + 1 sentinel column

    # Start of each row inside the flattened [batch_size x width] matrix.
    offsets = torch.arange(batch_size, device=device) * width
    ones = torch.ones(batch_size, 1, dtype=torch.float32, device=device)
    max_num = batch_size * width

    sample_scores = 1 + negative_predictions - positive_predictions
    # Add column of ones so we know when we used all our attempts: it is always > 0,
    # so every row gets a finite try count, and reaching it means no real negative fired.
    sample_scores = torch.cat([sample_scores, ones], dim=1)
    negative_predictions = torch.cat([negative_predictions, ones], dim=1)

    tries = num_tries_gt_zero(sample_scores, batch_size, max_trials, max_num, device)
    attempts = tries.float()
    trial_offset = (tries - 1) + offsets

    # Clamp the rank estimate at 1: when attempts > num_labels - 1 the floor is 0,
    # and log(0) = -inf would make the row -inf (or NaN after the 0 mask below).
    rank_estimate = torch.clamp(torch.floor((num_labels - 1) / attempts), min=1)
    loss_weights = torch.log(rank_estimate)
    # Don't count loss if we used max number of attempts (only the sentinel fired).
    should_count_loss = (attempts <= max_trials).float()

    chosen_negatives = negative_predictions.reshape(-1)[trial_offset]
    losses = loss_weights * ((1 - positive_predictions.reshape(-1)) + chosen_negatives) * should_count_loss

    return losses.sum()


In [None]:
#hide
import numpy as np


cpu_device = torch.device('cpu')
# Total label count used by the tests; the same value doubles as `max_num` for num_tries_gt_zero.
num_labels = 10
max_value = num_labels

## warp-loss tests
def ground_truth(pos_val, neg_val, num_attempts=1, num_labels=10):
    '''NumPy reference for a single row's WARP loss: log(floor((num_labels - 1) / num_attempts)) * ((1 - pos) + neg).'''
    rank_estimate = np.floor((num_labels - 1) / float(num_attempts))
    weight = np.log(rank_estimate)
    margin_term = (1 - pos_val) + neg_val
    return weight * margin_term

# Hand-checked fixture: the first positive score per row is at tries 2, 3 and 1.
comp_pos = torch.tensor([[0.1], [0.9], [1.0]], dtype=torch.float32)
comp_neg = torch.tensor([[-1.0, 0.3, 0.5],
                         [-1.0, -1.0, 0.3],
                         [0.3, 0.5, -1.0]], dtype=torch.float32)
comp_scores = comp_neg.add(1).sub(comp_pos)

def test_num_tries():
    '''num_tries_gt_zero returns the 1-based index of the first positive score in each row.'''
    simple = torch.FloatTensor([[0.5, -0.5], [-0.5, 0.5]])
    res = num_tries_gt_zero(simple, 2, 2, max_value, cpu_device)
    # torch.equal also compares shapes, so a result with missing or extra rows fails
    # (the previous element-wise loop only looked at len(res) entries).
    assert torch.equal(res.long(), torch.LongTensor([1, 2]))

    res = num_tries_gt_zero(comp_scores, 3, 3, max_value, cpu_device)
    assert torch.equal(res.long(), torch.LongTensor([2, 3, 1]))

    # A row with no positive score reports max_num.
    res = num_tries_gt_zero(torch.FloatTensor([[-1, -1]]), 1, 2, max_value, cpu_device)
    assert torch.equal(res.long(), torch.LongTensor([max_value]))


def test_ground_truth():
    '''warp_loss must agree with the NumPy reference `ground_truth` on random and hand-checked inputs.'''
    # Dedicated, seeded generator: reproducible draws without touching the global RNG.
    gen = torch.Generator().manual_seed(0)
    # these variables should always trigger on first index: rand() lies in [0, 1),
    # so 1 + neg - pos > 0 for every negative
    pos = torch.rand(2, 1, generator=gen).to(cpu_device)
    neg = torch.rand(2, 3, generator=gen).to(cpu_device)
    res = warp_loss(pos.view(-1, 1), neg, num_labels, cpu_device)

    expected = sum(ground_truth(pos[i].item(), neg[i][0].item(), num_labels=num_labels) for i in range(2))
    # torch computes in float32 and the reference in float64: compare with a tolerance, not ==.
    assert np.allclose(res.item(), expected)

    res = warp_loss(comp_pos.view(-1, 1), comp_neg, num_labels, cpu_device)
    tries = num_tries_gt_zero(comp_scores, 3, 3, max_value, cpu_device).tolist()
    expected = sum(ground_truth(comp_pos[i].item(), comp_neg[i][idx - 1].item(),
                                num_attempts=idx, num_labels=num_labels)
                   for i, idx in enumerate(tries))
    assert np.allclose(res.item(), expected)

def test_no_offending_scores():
    '''Every score is 1 + (-1) - 1 = -1, so only the sentinel fires, every row is masked, and the loss is 0.'''
    pos = torch.ones(2, dtype=torch.float32)
    neg = torch.full((2, 3), -1.0, dtype=torch.float32)
    loss = warp_loss(pos.view(-1, 1), neg, num_labels, cpu_device)
    assert loss.item() == 0

# Run every check when this cell executes (e.g. on Restart & Run All) so regressions fail loudly.
test_num_tries()
test_ground_truth()
test_no_offending_scores()