# CORAL Implementation Recipe

This notebook provides a brief overview of the parts of CORAL that distinguish it from a "regular" deep neural network for classification. The purpose of this overview is to provide a succinct resource that may help other researchers port the CORAL framework to other, non-PyTorch code bases (such as TensorFlow, Keras, MXNet, etc.).

## 1) Implement a function that converts class labels

- In contrast to regular cross entropy-based classification approaches, CORAL operates on binary tasks rather than integer class labels (or one-hot encoding representations thereof); hence, we have to convert class labels into the respective representation, which is illustrated here. 
- To provide an example, suppose you have a dataset consisting of 5 classes; consequently, the class labels are 0, 1, 2, 3, and 4.
- The following `label_to_levels` function converts class labels into the binary task representation required by CORAL, which we call "levels":

In [24]:
import torch


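def label_to_levels(label, num_classes):
    # Binary task k asks "is the label greater than k?", so a label r
    # becomes r ones followed by (num_classes - 1 - r) zeros.
    levels = [1]*label + [0]*(num_classes - 1 - label)
    levels = torch.tensor(levels, dtype=torch.float32)
    return levels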

- To continue with the example, assume we have a dataset of 3 training examples with class labels 2, 1, and 4.
- The converted labels would look as follows:

In [32]:
NUM_CLASSES = 5

levels = []
class_labels = [2, 1, 4]

for label in class_labels:
    levels_from_label = label_to_levels(label, num_classes=NUM_CLASSES)
    levels.append(levels_from_label)

levels = torch.stack(levels)
levels

tensor([[1., 1., 0., 0.],
        [1., 0., 0., 0.],
        [1., 1., 1., 1.]])
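
The per-label loop above can also be written as a single broadcast comparison, which may be convenient for whole mini-batches. The following is a minimal sketch; `labels_to_levels_batch` is a hypothetical helper name introduced here for illustration and is not part of the original code.

```python
import torch


def labels_to_levels_batch(labels, num_classes):
    # labels: 1-D integer tensor with values in [0, num_classes - 1]
    # Binary task k (k = 0, ..., num_classes - 2) asks "is the label > k?"
    thresholds = torch.arange(num_classes - 1)
    # (batch, 1) > (num_classes - 1,) broadcasts to (batch, num_classes - 1)
    return (labels.unsqueeze(1) > thresholds).to(torch.float32)


batch_levels = labels_to_levels_batch(torch.tensor([2, 1, 4]), num_classes=5)
print(batch_levels)
```

For the labels 2, 1, 4 this yields the same rows as the loop-based version.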

## 2) Implement the loss function

As outlined in the paper, the loss function is defined as

$$
\begin{aligned}
L(\mathbf{W}, \mathbf{b})=& \\
&-\sum_{i=1}^{N} \sum_{k=1}^{K-1} \lambda^{(k)}\left[\log \left(s\left(g\left(\mathbf{x}_{i}, \mathbf{W}\right)+b_{k}\right)\right) y_{i}^{(k)}\right.\\
&\left.+\log \left(1-s\left(g\left(\mathbf{x}_{i}, \mathbf{W}\right)+b_{k}\right)\right)\left(1-y_{i}^{(k)}\right)\right]
\end{aligned}
$$

- In the paper, we used a uniform task importance weight $\lambda$ (that is, all binary tasks were treated equally); this can be achieved by using a vector of 1's as the task importance weights.

In [26]:
importance_weights = torch.ones(NUM_CLASSES-1, dtype=torch.float)

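- The loss function itself, based on the equation above, can be implemented as follows (note that this implementation averages over the training examples, whereas the equation sums over them):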

In [27]:
import torch.nn.functional as F


def loss_fn1(logits, levels, imp):
    val = -torch.sum((torch.log(torch.sigmoid(logits))*levels
                      + torch.log(1 - torch.sigmoid(logits))*(1-levels))*imp,
                     dim=1)
    return torch.mean(val)

- To apply it to a concrete example, we use the previous "levels" and some made-up logit values (these logit values would be the neural network outputs).
- Note that the rows represent the training examples, whereas the columns represent the logit value for each binary task:

In [28]:
logits = torch.tensor([[2.1, 1.8, -2.1, -1.8],
                       [1.9, -1., -1.5, -1.3],
                       [1.9, 1.8, 1.7, 1.6]])

In [29]:
loss_fn1(logits=logits, 
         levels=levels,
         imp=importance_weights)

tensor(0.6920)

In practice, we found that the loss function can be numerically more stable if we rewrite


(1) 

```python
torch.log(torch.sigmoid(logits))*levels
```

as 

```python 
F.logsigmoid(logits)*levels
```

and (2)

```python 
torch.log(1 - torch.sigmoid(logits))*(1-levels)
```

as 

```python 
(F.logsigmoid(logits) - logits)*(1-levels)
```

Note that (2) is valid since

$$
\begin{aligned}
\log\bigg(\frac{e^x}{1+e^x}\bigg) - x &= \log\bigg(\frac{e^x}{1+e^x}\bigg) - \log(e^x)\\
&= \log\bigg(\frac{1}{1+e^x}\bigg)\\
&= \log\bigg(1-\frac{e^x}{1+e^x}\bigg)
\end{aligned}
$$
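
A quick numerical check can make the stability argument concrete. The sketch below compares the naive expression with the rewritten one in float32, using an arbitrarily chosen large logit of 50; the exact point at which the naive form breaks down depends on the floating-point precision.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-3.0, 0.5, 50.0])

# Naive form: sigmoid(50.) rounds to exactly 1 in float32,
# so log(1 - 1) evaluates to -inf
naive = torch.log(1 - torch.sigmoid(x))

# Rewritten form: stays finite and is close to -x for large x
stable = F.logsigmoid(x) - x

print(naive)
print(stable)
```

For moderate logits both forms agree; only the last entry differs, where the naive form overflows to negative infinity.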

- Hence, in practice, we recommend using the following loss function (which produces the same results as the one outlined above):

In [30]:
def loss_fn2(logits, levels, imp):
    val = (-torch.sum((F.logsigmoid(logits)*levels
                      + (F.logsigmoid(logits) - logits)*(1-levels))*imp,
           dim=1))
    return torch.mean(val)

In [31]:
loss_fn2(logits=logits, 
         levels=levels,
         imp=importance_weights)

tensor(0.6920)
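
For porting or sanity-checking, it may help to note that each per-task term has the form of a weighted binary cross-entropy on logits. The sketch below assumes PyTorch's built-in `F.binary_cross_entropy_with_logits` and compares it against `loss_fn2` on the example values from above; `loss_fn_bce` is a hypothetical name introduced here and is not from the original code.

```python
import torch
import torch.nn.functional as F


def loss_fn2(logits, levels, imp):
    # Recommended loss from above, repeated so this block runs on its own
    val = (-torch.sum((F.logsigmoid(logits)*levels
                      + (F.logsigmoid(logits) - logits)*(1-levels))*imp,
           dim=1))
    return torch.mean(val)


def loss_fn_bce(logits, levels, imp):
    # Element-wise BCE on logits; `weight` broadcasts over the batch dimension
    per_task = F.binary_cross_entropy_with_logits(
        logits, levels, weight=imp, reduction="none")
    # Sum over the binary tasks, then average over the training examples
    return per_task.sum(dim=1).mean()


levels = torch.tensor([[1., 1., 0., 0.],
                       [1., 0., 0., 0.],
                       [1., 1., 1., 1.]])
logits = torch.tensor([[2.1, 1.8, -2.1, -1.8],
                       [1.9, -1., -1.5, -1.3],
                       [1.9, 1.8, 1.7, 1.6]])
imp = torch.ones(4, dtype=torch.float)

loss_a = loss_fn2(logits, levels, imp)
loss_b = loss_fn_bce(logits, levels, imp)
print(loss_a, loss_b)
```

Frameworks that offer a comparable "sigmoid cross-entropy with logits" primitive could presumably be used in the same way, although that is not verified here.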

## 3) Modify the neural network architecture

The modification that has to be made to an existing deep neural network classifier is relatively simple as it only affects the last layer (i.e., output layer). In particular, the last fully connected layer has to be changed so that all binary tasks share a single weight vector while each task has its own bias unit.

In PyTorch, this means 

(1) changing the last fully connected layer

```python
...
self.fc = nn.Linear(input_size, num_classes)
```

to

```python
...
self.fc = nn.Linear(input_size, 1, bias=False)
self.linear_1_bias = nn.Parameter(torch.zeros(self.num_classes-1).float())
```

(2) and changing the forward pass from

```python
...
logits = self.fc(x)
probas = F.softmax(logits, dim=1)
return logits, probas
```
        
to

```python
logits = self.fc(x) + self.linear_1_bias
probas = torch.sigmoid(logits)
return logits, probas
```
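
Putting both changes together, a minimal self-contained module might look as follows. This is only a sketch: the class name `CoralMLP`, the hidden-layer size, and the small feature extractor are placeholders chosen for illustration and are not taken from the paper.

```python
import torch
import torch.nn as nn


class CoralMLP(nn.Module):
    def __init__(self, num_features, num_classes, hidden_size=16):
        super().__init__()
        self.num_classes = num_classes
        # Placeholder feature extractor; any backbone could be used here
        self.features = nn.Sequential(
            nn.Linear(num_features, hidden_size),
            nn.ReLU(),
        )
        # Single output unit without bias: weights shared by all binary tasks
        self.fc = nn.Linear(hidden_size, 1, bias=False)
        # One separate bias per binary task
        self.linear_1_bias = nn.Parameter(torch.zeros(num_classes - 1).float())

    def forward(self, x):
        x = self.features(x)
        # (batch, 1) + (num_classes - 1,) broadcasts to (batch, num_classes - 1)
        logits = self.fc(x) + self.linear_1_bias
        probas = torch.sigmoid(logits)
        return logits, probas


model = CoralMLP(num_features=10, num_classes=5)
logits, probas = model(torch.randn(3, 10))
print(logits.shape)
```

Because the biases are initialized to zero, all task logits of an example are identical before training; they separate as the bias units are learned.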

## 4) Evaluate the neural network

Computing performance metrics such as the mean absolute error, or simply obtaining class labels, requires a small modification. Whereas in cross entropy-based classifiers, we obtain the class labels via

```python
logits, probas = model(features)
_, predicted_labels = torch.max(probas, 1)
```

we can change these lines to

```python
logits, probas = model(features)
predict_levels = probas > 0.5
predicted_labels = torch.sum(predict_levels, dim=1)
```

in CORAL.
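
As a small worked example, the snippet below applies this rule to the made-up logits from Section 2 and computes the mean absolute error against the class labels 2, 1, 4. The variable names `true_labels` and `mae` are illustrative.

```python
import torch

logits = torch.tensor([[2.1, 1.8, -2.1, -1.8],
                       [1.9, -1., -1.5, -1.3],
                       [1.9, 1.8, 1.7, 1.6]])
true_labels = torch.tensor([2, 1, 4])

probas = torch.sigmoid(logits)
# A binary task counts as "passed" if its probability exceeds 0.5
predict_levels = probas > 0.5
# The predicted label is the number of passed binary tasks
predicted_labels = torch.sum(predict_levels, dim=1)

mae = torch.mean(torch.abs(predicted_labels - true_labels).float())
print(predicted_labels)  # tensor([2, 1, 4])
print(mae)               # tensor(0.)
```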