Like [@laevatein](https://www.kaggle.com/laevatein/tweat-the-loss-function-a-bit), I wanted to add a loss term that makes my model aware of the gap between the start and end indices.  

So I was pleasantly surprised to find the same idea in [@laevatein](https://www.kaggle.com/laevatein/tweat-the-loss-function-a-bit)'s cool work.

I designed a loss function that adds a 'distance loss' to my model.  

### **Distance** between the prediction's start~end and the ground truth's start~end   
- if the prediction's length differs greatly from the ground truth's, give it a large penalty (approaching infinity, ∞)  
- if the two lengths are similar, the penalty is near zero
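
The two cases above can be written compactly. This is a sketch read off the `dist_loss` code further down; the symbols are notation introduced here for illustration. $p^{\text{start}}$ and $p^{\text{end}}$ are the softmax probabilities over token positions, $s_b$ and $e_b$ are the ground-truth start and end indices of sample $b$, $N$ is `max_seq_len`, and $B$ is the batch size.

$$
\begin{aligned}
d_{\text{pred}} &= \frac{1}{B}\sum_{b=1}^{B}\sum_{i=0}^{N-1}\frac{i}{N}\left(p^{\text{end}}_{b,i}-p^{\text{start}}_{b,i}\right),\\
d_{\text{gt}} &= \frac{1}{B}\sum_{b=1}^{B}\frac{e_b-s_b}{N},\\
\mathcal{L}_{\text{dist}} &= -\,\text{scale}\cdot\log\left(1-\left|d_{\text{gt}}-d_{\text{pred}}\right|\right)
\end{aligned}
$$

When $\left|d_{\text{gt}}-d_{\text{pred}}\right|$ is close to 0 the loss is close to 0, and it grows without bound as the difference approaches 1.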


It is of course not perfect, but I am sharing it anyway and look forward to improving it with your feedback.

### Updated
- updated the one-hot creation code in `dist_loss()` to use `torch.nn.functional.one_hot` ([@Yu Kang](https://www.kaggle.com/karlyukang)'s feedback)

![](https://user-images.githubusercontent.com/8045508/80869539-22a2a680-8cdc-11ea-85be-d2af6cb7babb.jpeg)

The linearly increasing values in the figure correspond to `linear_func` in the code.

## Distance Loss

In [None]:
import numpy as np
import torch

In [None]:
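def dist_between(start_logits, end_logits, device='cpu', max_seq_len=128):
    """Return the batch-averaged distance between the expected start and end positions.

    Both inputs are per-token weights of shape (batch, max_seq_len):
    softmax probabilities for predictions, one-hot vectors for the ground truth.
    """

    # linearly increasing position values in [0, 1): index i maps to i / max_seq_len
    linear_func = torch.tensor(np.linspace(0, 1, max_seq_len, endpoint=False), requires_grad=False)
    linear_func = linear_func.to(device)

    # expected (normalized) position under each distribution, shape (batch,)
    start_pos = (start_logits*linear_func).sum(axis=1)
    end_pos = (end_logits*linear_func).sum(axis=1)

    diff = end_pos-start_pos

    # mean over the batch
    return diff.sum(axis=0)/diff.size(0)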


def dist_loss(start_logits, end_logits, start_positions, end_positions, device='cpu', max_seq_len=128, scale=1):
    """Calculate the distance loss between the prediction's length and the GT's length.
    
    Input
    - start_logits ; shape (batch, max_seq_len{128})
        - logits for the start index
    - end_logits ; shape (batch, max_seq_len{128})
        - logits for the end index
    - start_positions ; shape (batch,), integer (long) tensor
        - start index of the GT
    - end_positions ; shape (batch,), integer (long) tensor
        - end index of the GT
    """
    start_logits = torch.nn.Softmax(1)(start_logits) # shape ; (batch, max_seq_len)
    end_logits = torch.nn.Softmax(1)(end_logits)
    
    start_one_hot = torch.nn.functional.one_hot(start_positions, num_classes=max_seq_len).to(device)
    end_one_hot = torch.nn.functional.one_hot(end_positions, num_classes=max_seq_len).to(device)
    
    pred_dist = dist_between(start_logits, end_logits, device, max_seq_len)
    gt_dist = dist_between(start_one_hot, end_one_hot, device, max_seq_len) # non-negative as long as end >= start
    diff = (gt_dist-pred_dist)

    # torch.abs gives the same value as torch.sqrt(diff*diff) but has a finite gradient at diff == 0
    rev_diff = 1-torch.abs(diff) # the smaller the diff, the closer this gets to one (it must stay > 0 for the log)
    loss = -torch.log(rev_diff) # negative log: argument near zero -> infinity, near one -> zero

    return loss*scale
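
One thing to be aware of: `dist_between` averages the start-to-end distance over the batch before `dist_loss` compares prediction and ground truth, so a span that is too long and a span that is too short in the same batch can cancel each other out. The sketch below is a hypothetical per-sample variant, not the version used for the results in this notebook. The name `per_sample_dist_loss` and the clamp are my own assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def per_sample_dist_loss(start_logits, end_logits, start_positions, end_positions, scale=1.0):
    """Hypothetical variant: compare span lengths per sample, then average.

    start_logits, end_logits: float tensors of shape (batch, max_seq_len)
    start_positions, end_positions: int64 tensors of shape (batch,)
    """
    max_seq_len = start_logits.size(1)
    # index i maps to i / max_seq_len, the same ramp as linear_func
    positions = torch.arange(
        max_seq_len, dtype=start_logits.dtype, device=start_logits.device
    ) / max_seq_len

    start_probs = F.softmax(start_logits, dim=1)
    end_probs = F.softmax(end_logits, dim=1)

    # expected normalized span length per sample, shape (batch,)
    pred_len = (end_probs * positions).sum(dim=1) - (start_probs * positions).sum(dim=1)
    gt_len = (end_positions - start_positions).to(start_logits.dtype) / max_seq_len

    # |diff| can reach 1 or more when the predicted end lies far before the
    # predicted start; the clamp is an added safeguard so the log stays finite
    diff = (gt_len - pred_len).abs().clamp(max=1 - 1e-6)
    return scale * (-torch.log(1 - diff)).mean()


# Two samples whose length errors have opposite signs
start_logits = torch.zeros(2, 10)
end_logits = torch.zeros(2, 10)
start_logits[:, 1] = 8.0              # both predictions start near index 1
end_logits[0, 8] = 8.0                # sample 0 predicts a long span (1~8)
end_logits[1, 1] = 8.0                # sample 1 predicts a short span (1~1)
start_positions = torch.tensor([1, 1])
end_positions = torch.tensor([1, 8])  # GT lengths are swapped relative to the predictions

loss = per_sample_dist_loss(start_logits, end_logits, start_positions, end_positions)
print(loss)
```

In this toy batch the mean predicted length and the mean GT length are almost equal, so a batch-averaged comparison sees nearly no error, while the per-sample version penalizes both samples.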


### len(prediction) < len(GT)
- Prediction ; approximately 1~3
- GT ; 1~8

In [None]:
start_logits = torch.zeros(10)
start_logits[0] = 1
start_logits[1] = 8
start_logits[2] = 1

end_logits = torch.zeros(10)
end_logits[2] = 2
end_logits[3] = 6
end_logits[4] = 2

start_pos = torch.tensor(1)
end_pos = torch.tensor(8)

dist_loss(
    torch.unsqueeze(start_logits, 0),
    torch.unsqueeze(end_logits, 0),
    torch.unsqueeze(start_pos, 0),
    torch.unsqueeze(end_pos, 0),
    max_seq_len=10,
)

### len(prediction) > len(GT)
- Prediction ; approximately 1~8
- GT ; 1~1

In [None]:
start_logits = torch.zeros(10)
start_logits[0] = 1
start_logits[1] = 8
start_logits[2] = 1

end_logits = torch.zeros(10)
end_logits[7] = 1
end_logits[8] = 8
end_logits[9] = 1

start_pos = torch.tensor(1)
end_pos = torch.tensor(1)

dist_loss(
    torch.unsqueeze(start_logits, 0),
    torch.unsqueeze(end_logits, 0),
    torch.unsqueeze(start_pos, 0),
    torch.unsqueeze(end_pos, 0),
    max_seq_len=10,
)

### len(prediction) ~= len(GT)
- Prediction ; approximately 1~8
- GT ; 1~8

In [None]:
start_logits = torch.zeros(10)
start_logits[0] = 1
start_logits[1] = 8
start_logits[2] = 1

end_logits = torch.zeros(10)
end_logits[7] = 1
end_logits[8] = 8
end_logits[9] = 1

start_pos = torch.tensor(1)
end_pos = torch.tensor(8)

dist_loss(
    torch.unsqueeze(start_logits, 0),
    torch.unsqueeze(end_logits, 0),
    torch.unsqueeze(start_pos, 0),
    torch.unsqueeze(end_pos, 0),
    max_seq_len=10,
)

## Notice
- for the same span length in tokens, the distance gets smaller as **max_seq_len** increases, because positions are scaled to the range [0, 1) by dividing by max_seq_len
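
A small check of this effect, as a sketch: `normalized_span_length` is a hypothetical helper written for this note (it is not part of `dist_loss`), and it measures the same 7-token gap with different values of `max_seq_len`.

```python
import torch
import torch.nn.functional as F


def normalized_span_length(start_idx, end_idx, max_seq_len):
    """Span length measured through one-hot vectors and a [0, 1) position ramp."""
    positions = torch.arange(max_seq_len, dtype=torch.float32) / max_seq_len
    start_one_hot = F.one_hot(torch.tensor([start_idx]), num_classes=max_seq_len)
    end_one_hot = F.one_hot(torch.tensor([end_idx]), num_classes=max_seq_len)
    return ((end_one_hot - start_one_hot) * positions).sum().item()


# the same span (1~8) measured with different sequence lengths
for n in (10, 128, 512):
    print(n, normalized_span_length(1, 8, n))
```

With `max_seq_len=128` the gap is only 7/128, so a length error of 7 tokens gives a loss of roughly 0.056. If that is too weak next to the cross-entropy terms, raising the `scale` argument of `dist_loss` is one option.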

## Example

```python
start_logits, end_logits = model(token_ids,
                                 token_type_ids=token_type_ids,
                                 attention_mask=attention_mask)

start_loss = torch.nn.CrossEntropyLoss()(start_logits, start_positions)
end_loss = torch.nn.CrossEntropyLoss()(end_logits, end_positions)

idx_loss = (start_loss+end_loss)

dist_loss = utils.dist_loss(
    start_logits, end_logits,
    start_positions, end_positions,
    device, cfg.MAX_SEQ_LEN) 

total_loss = idx_loss + dist_loss
```
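
To try the combination without a real model, here is a minimal runnable sketch. The random logits stand in for the model output, and the `dist_loss` in this snippet is a condensed stand-in for the function defined above so that the snippet runs on its own. Both are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def dist_loss(start_logits, end_logits, start_positions, end_positions, scale=1.0):
    """Condensed stand-in for the dist_loss defined earlier (batch-averaged lengths)."""
    n = start_logits.size(1)
    positions = torch.arange(n, dtype=start_logits.dtype, device=start_logits.device) / n
    probs_diff = F.softmax(end_logits, dim=1) - F.softmax(start_logits, dim=1)
    pred = (probs_diff * positions).sum(dim=1).mean()
    gt = ((end_positions - start_positions).to(start_logits.dtype) / n).mean()
    return -torch.log(1 - torch.abs(gt - pred)) * scale


torch.manual_seed(0)
batch, max_seq_len = 4, 128

# random logits in place of model(token_ids, ...)
start_logits = torch.randn(batch, max_seq_len, requires_grad=True)
end_logits = torch.randn(batch, max_seq_len, requires_grad=True)
start_positions = torch.tensor([3, 10, 0, 40])
end_positions = torch.tensor([7, 12, 5, 90])

idx_loss = (F.cross_entropy(start_logits, start_positions)
            + F.cross_entropy(end_logits, end_positions))
# scale controls how much the distance term counts next to the cross-entropy terms
length_loss = dist_loss(start_logits, end_logits, start_positions, end_positions, scale=1.0)

total_loss = idx_loss + length_loss
total_loss.backward()  # gradients reach both sets of logits
print(idx_loss.item(), length_loss.item())
```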

## Test with...
- model ; RoBERTa
- loss ; crossEntropyLoss(with start_logits, end_logits) + distanceLoss
- max_seq_len ; 128
- batch size ; 128
- learning rate ; 9e-5
- epoch ; 3
- scheduler ; cosine warmup scheduler
- early stopping on the validation Jaccard score with patience=3
    - validation is run 4 times per epoch
- 5-fold using Stratified K-fold (sklearn) with random seed (293984)
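
The early-stopping part of this setup can be sketched as below. I am assuming a word-level Jaccard score here, and `EarlyStopping` is a hypothetical helper, not the exact code used for these runs. The validation scores in the loop are made-up numbers, only there to show the mechanics.

```python
def jaccard(pred: str, truth: str) -> float:
    """Word-level Jaccard similarity between two strings (assumed metric)."""
    a = set(pred.lower().split())
    b = set(truth.lower().split())
    if not a and not b:
        return 1.0  # assumption: two empty strings count as a perfect match
    inter = a & b
    return len(inter) / (len(a) + len(b) - len(inter))


class EarlyStopping:
    """Hypothetical helper: stop after `patience` checks without improvement."""

    def __init__(self, patience: int = 3):
        self.patience = patience
        self.best = float("-inf")
        self.bad_checks = 0

    def step(self, score: float) -> bool:
        """Record a validation score; return True when training should stop."""
        if score > self.best:
            self.best = score
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = EarlyStopping(patience=3)
# made-up validation scores, one per check (4 checks per epoch in the setup above)
for check, score in enumerate([0.690, 0.702, 0.701, 0.700, 0.699]):
    if stopper.step(score):
        print(f"stop at check {check}, best score {stopper.best:.3f}")
        break
```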


## Result

### Not Using Distance Loss
- fold 1 ; 0.7007
- fold 2 ; 0.7088
- fold 3 ; 0.7113
- fold 4 ; 0.7070
- fold 5 ; 0.7041
- avg ; 0.7065

### Using Distance Loss
- fold 1 ; 0.7061
- fold 2 ; 0.7128
- fold 3 ; 0.7139
- fold 4 ; 0.7043
- fold 5 ; 0.7086
- avg ; 0.7091