# Multitask learning

Tesla's HydraNet is an example of multitask learning. Such a beast is extremely hard to train, and it is just as hard to have multiple teams working in the same large network.

Within the beast network, prediction and behavior (decision making) are natural candidates for joint training: many self-driving engineers (including me) believe these two tasks are deeply correlated and can be solved more efficiently when modeled jointly. The general idea is to encode the environment with a backbone network (e.g. ResNet or DenseNet) and then attach multiple heads for the prediction and decision tasks.

This results in the so-called hard-sharing (hard parameter sharing) network.

![](../assets/multitask/multitask.png)
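A minimal sketch of such a hard-sharing network, assuming a toy CNN in place of a real ResNet/DenseNet backbone. The class name `HardSharingNet`, the layer sizes, the number of decisions and the trajectory horizon are all hypothetical and only there to make the example runnable.

```python
import torch
import torch.nn as nn


class HardSharingNet(nn.Module):
    """Shared backbone with one head per task (hypothetical layer sizes)."""

    def __init__(self, in_channels=3, feat_dim=32, num_decisions=4, horizon=10):
        super().__init__()
        # Small CNN standing in for a ResNet/DenseNet backbone.
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, feat_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Decision head: classification over discrete maneuvers.
        self.decision_head = nn.Linear(feat_dim, num_decisions)
        # Prediction head: regression of (x, y) waypoints over the horizon.
        self.prediction_head = nn.Linear(feat_dim, horizon * 2)
        self.horizon = horizon

    def forward(self, x):
        feat = self.backbone(x)  # features shared by every head
        decision_logits = self.decision_head(feat)
        trajectory = self.prediction_head(feat).view(-1, self.horizon, 2)
        return decision_logits, trajectory


net = HardSharingNet()
logits, traj = net(torch.randn(2, 3, 64, 64))
print(logits.shape, traj.shape)
```

Every gradient from either head flows into the same backbone weights, which is exactly why balancing the tasks matters.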

This network is difficult to train well because:

1. There are a bunch of heads, each a classification or regression task. Some are easier, some are harder. Balancing the training effort across them is critical to reaching a better minimum.
2. The datasets for the tasks can differ in size and noise level.
3. In practice, the baseline is usually a single-task network, so you may already have a good baseline for the prediction or the decision task. However, you cannot simply freeze that baseline as the main network and train the other task as an extra head, because the input domains are sometimes not exactly the same. For example, route information is needed for the decision task but not for the prediction task. A workaround is to inject the route information after the backbone and encode it again with a CNN, which is not an elegant solution.

The first problem is usually tackled through loss design and gradient design. Oversampling the smaller datasets can bring each task's training set to the same scale as the others.
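One possible way to oversample is PyTorch's `WeightedRandomSampler`. The sketch below assumes two hypothetical datasets of very different size that share one input format; the second tensor in each dataset is just a task id used to show the resulting mix. Real multitask data will need its own label handling.

```python
import torch
from torch.utils.data import (
    ConcatDataset,
    DataLoader,
    TensorDataset,
    WeightedRandomSampler,
)

torch.manual_seed(0)

# Hypothetical datasets: task A has 1000 samples, task B only 100.
task_a = TensorDataset(torch.randn(1000, 8), torch.zeros(1000, dtype=torch.long))
task_b = TensorDataset(torch.randn(100, 8), torch.ones(100, dtype=torch.long))
combined = ConcatDataset([task_a, task_b])

# Weight every sample by the inverse of its dataset size, so that both
# tasks are drawn with roughly equal probability.
weights = torch.cat([
    torch.full((len(task_a),), 1.0 / len(task_a)),
    torch.full((len(task_b),), 1.0 / len(task_b)),
])
sampler = WeightedRandomSampler(weights, num_samples=2 * len(task_a), replacement=True)
loader = DataLoader(combined, batch_size=64, sampler=sampler)

# Check the mix: the fraction of task-B samples should be close to 0.5.
task_ids = torch.cat([task_id for _, task_id in loader])
print(task_ids.float().mean().item())
```

Sampling with replacement repeats the small dataset many times per epoch, so it is worth watching that task for overfitting.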


## Weighted sum of loss

The naive approach is:
$$
Loss = \sum_i w_i L_i
$$
The weight $w_i$ of each task is set manually by the engineer. Usually the weights are chosen to scale each task loss to the same level.
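As a rough sketch of "scaling each task loss to the same level", one option is to take the inverse of each loss at some reference point, such as the first few iterations. The helper names `balanced_weights` and `weighted_sum` and the reference values are hypothetical.

```python
def balanced_weights(reference_losses, eps=1e-8):
    """Return one weight per task so that w_i * L_i is about 1 at the reference point."""
    return [1.0 / max(loss, eps) for loss in reference_losses]


def weighted_sum(losses, weights):
    """Loss = sum_i w_i * L_i; works for floats and for scalar tensors alike."""
    return sum(w * l for w, l in zip(weights, losses))


# Hypothetical reference losses: a cross-entropy near 2.3 and an MSE near 50.
reference = [2.3, 50.0]
weights = balanced_weights(reference)
total = weighted_sum(reference, weights)
print(weights, total)
```

The weights stay fixed during training, so they only balance the tasks at the reference point and may need retuning as the losses drift.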

## [Uncertainty aware loss](https://openaccess.thecvf.com/content_cvpr_2018/papers/Kendall_Multi-Task_Learning_Using_CVPR_2018_paper.pdf)

This loss design turns the weight of each loss term into an optimization variable: each task gets its own learnable variance $\sigma_i^2$. The final loss is:

$$
Loss = \sum_i \frac{L_i}{\sigma_i^2} + \log \prod_i \sigma_i^2
$$

The variance term models [homoscedastic uncertainty](https://en.wikipedia.org/wiki/Homoscedasticity), which is a kind of aleatoric (statistical) uncertainty. This uncertainty is constant across input data but varies from task to task.

When a task has high uncertainty, its loss is down-weighted. Thus, the model learns the low-uncertainty tasks first.
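A short derivation makes the down-weighting concrete. It uses the simplified loss above, treats each task loss $L_i$ as fixed, and writes $s_i = \log \sigma_i^2$ (the same parameterization as `log_var` in the code):

$$
\frac{\partial}{\partial s_i}\left(L_i e^{-s_i} + s_i\right) = -L_i e^{-s_i} + 1 = 0 \quad\Rightarrow\quad \sigma_i^2 = e^{s_i} = L_i
$$

So at the stationary point each task is weighted by roughly $1/L_i$. During training $L_i$ keeps changing, so this is only a rough guide to where the learned weights drift.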

```python
import torch
import torch.nn as nn


class LossUnc(nn.Module):
    def __init__(self):
        # nn.Module must be initialized before parameters can be registered.
        super().__init__()
        self.log_vars = []
        self.task_name = ["task1", "task2"]
        self.t1_loss = nn.CrossEntropyLoss()
        self.t2_loss = nn.MSELoss()

        # One learnable log-variance, log(sigma_i^2), per task; 0 means sigma_i^2 = 1.
        for t in self.task_name:
            self.log_vars.append(nn.Parameter(torch.zeros(1), requires_grad=True))
            self.register_parameter(t, self.log_vars[-1])

    def forward(self, t1_pred, t2_pred, t1_y, t2_y):
        t1_loss = self.t1_loss(t1_pred, t1_y)
        t2_loss = self.t2_loss(t2_pred, t2_y)

        total_loss = 0
        for log_var, loss in zip(self.log_vars, [t1_loss, t2_loss]):
            # exp(-log_var) = 1 / sigma_i^2; the + log_var term penalizes large variances.
            total_loss = total_loss + torch.exp(-log_var) * loss + log_var
        return total_loss
```

## [Dynamic task prioritization](https://openaccess.thecvf.com/content_ECCV_2018/papers/Michelle_Guo_Focus_on_the_ECCV_2018_paper.pdf)

For each task, a KPI tracks how well the task has been learned so far. For example, for classification tasks the KPI ($k_i$) can be the accuracy; for trajectory prediction tasks it can be the top-1 hit rate. The KPI must lie in the range [0, 1], where 1 means perfect learning. The intuition is that the model should learn the difficult tasks first. In other words, tasks with a low KPI should receive more effort. A focal-loss-like term expresses this idea:

$$
\bar{k}_i^{(t)} = \alpha k_i^{(t)} + (1-\alpha) \bar{k}_i^{(t-1)} \\
w_i = -(1-\bar{k}_i^{(t)})^\gamma \log \bar{k}_i^{(t)} \\
Loss = \sum_i w_i L_i
$$

Here $k_i^{(t)}$ is the KPI of task $i$ measured at iteration $t$, and $\bar{k}_i^{(t)}$ is its moving average.

- The KPI has to be computed at each iteration and updated as a moving average for smoothness.
- The KPI usually does not need to be a learnable parameter.
- For distributed training, the KPI can be computed on the local GPU without cross-process communication (`torch.distributed.all_reduce()`).
- In the focal-loss-like weight, increase $\gamma$ to focus more on the hard (low-KPI) tasks.
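To get a feel for the numbers, here is a small pure-Python sketch of the moving average and the focal-style weight. It follows the convention of the `update_kpi` code in this section, where $\alpha$ weights the newest measurement. The helper names and the KPI values are hypothetical.

```python
import math


def update_kpi(prev_avg, current, alpha=0.5):
    """Exponential moving average; alpha weights the newest measurement."""
    return alpha * current + (1.0 - alpha) * prev_avg


def task_weight(kpi, gamma=1.0, eps=1e-8):
    """Focal-style task weight: -(1 - k)^gamma * log(k)."""
    return -((1.0 - kpi) ** gamma) * math.log(kpi + eps)


# Weights for a hard, a medium and an easy task under different gamma values.
for kpi in (0.1, 0.5, 0.9):
    print(kpi, [round(task_weight(kpi, gamma), 4) for gamma in (0.5, 1.0, 2.0)])
```

The weight falls as the KPI rises, and a larger $\gamma$ shrinks the weight of an easy task much more than that of a hard one.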

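```python
import math

import torch
import torch.nn as nn


def focal_loss(kpi, gamma=1):
    # Focal-style weight: large when kpi is close to 0, near 0 when kpi is close to 1.
    return -((1.0 - kpi) ** gamma) * math.log(kpi + 1e-8)


class LossDTP(nn.Module):
    def __init__(self):
        super().__init__()
        self.task_name = ["task1", "task2"]
        self.t1_loss = nn.CrossEntropyLoss()
        self.t2_loss = nn.MSELoss()

        # Moving-average KPI per task, kept as plain floats (not learnable).
        self.kpis = [0.0 for _ in self.task_name]
        self.alpha = 0.5

    @torch.no_grad()
    def metric1(self, pred, gt):
        # KPI for the classification task: top-1 accuracy in [0, 1].
        return (pred.argmax(dim=1) == gt).float().mean().item()

    @torch.no_grad()
    def metric2(self, pred, gt, tol=0.1):
        # Placeholder KPI for the regression task (an assumption, replace with your own):
        # fraction of predictions within `tol` of the target, in [0, 1].
        return ((pred - gt).abs() < tol).float().mean().item()

    @torch.no_grad()
    def update_kpi(self, t1_pred, t2_pred, t1_y, t2_y):
        kpi1 = self.metric1(t1_pred, t1_y)
        kpi2 = self.metric2(t2_pred, t2_y)

        # Exponential moving average; alpha weights the newest measurement.
        for i, cur_kpi in enumerate([kpi1, kpi2]):
            self.kpis[i] = self.alpha * cur_kpi + (1 - self.alpha) * self.kpis[i]

    def forward(self, t1_pred, t2_pred, t1_y, t2_y):
        t1_loss = self.t1_loss(t1_pred, t1_y)
        t2_loss = self.t2_loss(t2_pred, t2_y)

        self.update_kpi(t1_pred, t2_pred, t1_y, t2_y)

        total_loss = 0
        for kpi, loss in zip(self.kpis, [t1_loss, t2_loss]):
            # The weight is a constant here: no gradient flows through the KPI.
            w = focal_loss(kpi, 1)
            total_loss = total_loss + w * loss
        return total_loss
```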

## Summary

In practice, hand-picked fixed weights perform better than the uncertainty-aware method. DTP is worth trying when there are multiple tasks.