In [1]:
import torch
import torch.nn.functional as F

torch.manual_seed(13);

## Hierarchy

Ontologies can be represented as a tree, where the leaves (nodes at level 0) are the most specific classes. A mapping can then be defined from the nodes at level 0 to the nodes at any other level `l`.

One way to define such a mapping is a list of C integers, where C is the number of fine-grained classes (leaves). The list index is the fine class, and the value stored at that index is its ancestor node at level `l`. Stacking one such list per level gives the *"Hierarchy Tensor"*, an (L,C)-tensor. With this representation, the coarser class at level `l` of a fine class `c` is simply `hierarchy[l, c]`.
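The stacking described above can be sketched with a small helper, assuming each level of the tree is given as a parent map: a list whose index is a node at one level and whose value is that node's parent one level up. `build_hierarchy` and `parent_maps` are hypothetical names, not part of the original notebook; composing the parent maps yields the ancestor of every leaf at every level, excluding the root.

```python
import torch


def build_hierarchy(parent_maps):
    """Stack per-level parent maps into an (L, C) hierarchy tensor (hypothetical helper).

    parent_maps[k][i] is the parent, at level k + 1, of node i at level k.
    The root level is not included, matching the convention used in this notebook.
    """
    level = torch.arange(len(parent_maps[0]))  # level 0: every leaf maps to itself
    rows = [level]
    for parents in parent_maps:
        # Ancestor at level k + 1 = parent of the ancestor at level k
        level = torch.as_tensor(parents)[level]
        rows.append(level)
    return torch.stack(rows)


# Leaves -> level-1 nodes, as in the tree drawn in this notebook
hierarchy = build_hierarchy([[1, 0, 1, 2, 0, 1]])
print(hierarchy)

# Coarser class at level l of leaf c is hierarchy[l, c]
print(hierarchy[1, 4].item())  # 0: leaf 4 sits under level-1 node 0
```

Deeper trees only need one more parent map per extra level; the helper composes them so that each row still maps leaves directly to their ancestor at that level.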

In [2]:
# Hierarchy levels (L), Batch size (B), Number of classes (C)
L, B, C = 2, 4, 6

# outputs of the model (i.e. logits)
outputs = torch.rand(B, C) * 10 - 5

# targets (i.e. ground-truth labels)
targets = torch.randint(C, (B,)).long()

outputs, targets

(tensor([[-4.0818, -0.2062,  3.1055, -4.8489, -4.8470,  1.0357],
         [-2.6820,  3.6329,  4.8588, -3.0249, -4.1699, -0.7465],
         [ 4.1487, -0.2010,  0.3479, -2.3053, -2.4698, -1.6100],
         [ 3.3667, -3.7112,  4.6935, -0.5049, -0.9694,  3.2024]]),
 tensor([2, 1, 5, 3]))

```
      _________0_________     ------ level-2
     /         |         \
  __0__      __1__        2   ------ level-1
 /     \    /  |  \       |
1       4  2   0   5      3   ------ level-0

```

In [3]:
hierarchy = torch.tensor(
    [
        [0, 1, 2, 3, 4, 5],  # level 0 (each leaf maps to itself)
        [1, 0, 1, 2, 0, 1],  # level 1
        # Add other levels here ...
        # but do not include the root level.
    ],
    dtype=torch.long,  # integer dtype, since entries are class indices
)

assert hierarchy.shape == (L, C)

## Hierarchical Multi-Hot Encoding

For losses and metrics it is useful to have a multi-hot encoding of the `targets` ((B,)-tensor) at level `l`: a (B,C) mask that is True for every leaf sharing the target's ancestor at level `l`. It is obtained by indexing the *"Encoding Tensor"*, an (L,C,C)-tensor derived from the Hierarchy Tensor, as `encoding[l, targets]`, where `encoding[l, i, j]` is True when leaves `i` and `j` have the same ancestor at level `l`. Note that `encoding[0, targets]` is the usual one-hot encoding.

The encoding tensor has dtype `bool`; casting it to integers makes the printout easier to read.
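The nested loops in the next cell are easy to follow, but the same tensor can also be built without Python loops by broadcasting each level of the hierarchy against itself. A minimal sketch, assuming `hierarchy` is an integer (L,C)-tensor as defined above:

```python
import torch

hierarchy = torch.tensor([
    [0, 1, 2, 3, 4, 5],  # level 0
    [1, 0, 1, 2, 0, 1],  # level 1
])

# Compare (L, C, 1) against (L, 1, C): encoding[l, i, j] is True
# when leaves i and j have the same ancestor at level l.
encoding = hierarchy[:, :, None] == hierarchy[:, None, :]

targets = torch.tensor([2, 1, 5, 3])
print(encoding[1, targets].long())  # multi-hot targets at level 1
```

The result is a `bool` tensor of shape (L, C, C), identical to the one produced by the loops.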

In [4]:
# encoding[l, i, j] is True when leaves i and j share the same ancestor at level l
encoding = torch.zeros((L, C, C)).bool()

for l, level in enumerate(hierarchy):
    for i, label1 in enumerate(level):
        for j, label2 in enumerate(level):
            if label1 == label2:
                encoding[l, i, j] = True

encoding.long()

tensor([[[1, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0],
         [0, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 1]],

        [[1, 0, 1, 0, 0, 1],
         [0, 1, 0, 0, 1, 0],
         [1, 0, 1, 0, 0, 1],
         [0, 0, 0, 1, 0, 0],
         [0, 1, 0, 0, 1, 0],
         [1, 0, 1, 0, 0, 1]]])

```
      _________0_________     ------ level-2
     /         |         \
  __0__      __1__        2   ------ level-1
 /     \    /  |  \       |
1       4  2   0   5      3   ------ level-0

```

In [5]:
targets, encoding[1, targets].long()

(tensor([2, 1, 5, 3]),
 tensor([[1, 0, 1, 0, 0, 1],
         [0, 1, 0, 0, 1, 0],
         [1, 0, 1, 0, 0, 1],
         [0, 0, 0, 1, 0, 0]]))

In [6]:
assert (F.one_hot(targets, C) == encoding[0, targets]).all()

targets, encoding[0, targets].long()

(tensor([2, 1, 5, 3]),
 tensor([[0, 0, 1, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 1, 0, 0]]))
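As an illustration of how the multi-hot targets can feed metrics and losses, the sketch below (not part of the original notebook) computes two level-aware quantities: the accuracy at level `l`, where a prediction counts as correct if it falls in the target's level-`l` group, and a coarse-level cross-entropy obtained by summing the softmax probabilities of all leaves in that group. For `l = 0` the loss reduces to the usual cross-entropy. The logits are copied from the printout above.

```python
import torch
import torch.nn.functional as F

hierarchy = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 0, 1, 2, 0, 1]])
encoding = hierarchy[:, :, None] == hierarchy[:, None, :]  # (L, C, C) bool

# Logits and targets copied from the printout above (4-decimal values)
outputs = torch.tensor([
    [-4.0818, -0.2062,  3.1055, -4.8489, -4.8470,  1.0357],
    [-2.6820,  3.6329,  4.8588, -3.0249, -4.1699, -0.7465],
    [ 4.1487, -0.2010,  0.3479, -2.3053, -2.4698, -1.6100],
    [ 3.3667, -3.7112,  4.6935, -0.5049, -0.9694,  3.2024],
])
targets = torch.tensor([2, 1, 5, 3])
preds = outputs.argmax(dim=-1)
probs = F.softmax(outputs, dim=-1)

results = {}
for l in range(hierarchy.shape[0]):
    mask = encoding[l, targets].float()  # (B, C) multi-hot targets at level l
    # Correct at level l if the predicted leaf is in the target's level-l group
    accuracy = encoding[l, targets, preds].float().mean()
    # Negative log of the probability mass the model puts on the target's group
    loss = -(probs * mask).sum(dim=-1).log().mean()
    results[l] = (accuracy.item(), loss.item())
    print(f"level {l}: accuracy={accuracy.item():.2f}, loss={loss.item():.4f}")
```

Summing probabilities before taking the log means the model is not penalized at level `l` for confusing leaves that share the same ancestor there, so the level-1 loss is never larger than the level-0 loss.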

## Lowest Common Ancestor (LCA)
The *"Lowest Common Ancestor"* (LCA) of two leaves `c1` and `c2` is the deepest node that has both of them as descendants. In this implementation, `lca[c1, c2]` stores the level of that node, i.e. how many levels one has to climb from the leaves to reach it; `lca` is therefore a symmetric C x C matrix with zeros on the diagonal. More generally, the LCA is defined for any two nodes of the tree, not only for leaves.

The LCA can be used to define hierarchical metrics, such as the "hierarchical distance of a mistake" introduced in \[Ber+20\].
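Since two leaves that share an ancestor at some level also share every ancestor above it, the LCA level can also be read off the encoding tensor: it equals L minus the number of stored levels at which the two leaves agree. A sketch of this vectorized alternative to the loops in the next cell, rebuilding `hierarchy` and `encoding` so it runs on its own:

```python
import torch

hierarchy = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 0, 1, 2, 0, 1]])
L = hierarchy.shape[0]  # number of stored levels; the root sits at level L
encoding = hierarchy[:, :, None] == hierarchy[:, None, :]  # (L, C, C) bool

# Leaves agree on levels lca..L-1, i.e. on L - lca of the stored levels
lca = L - encoding.sum(dim=0)
print(lca)
```

Summing a `bool` tensor yields an integer tensor, so `lca` has the same integer dtype as the loop-based version.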

In [7]:
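# Start every pair at the root level L, then step down one level
# for each level at which the two leaves share the same ancestor.
lca = torch.full((C, C), L)

for level in hierarchy:
    for row, coarse in zip(lca, level):  # row is a view, so lca is updated in place
        for index, value in enumerate(level):
            if coarse == value:
                row[index] -= 1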
lca

tensor([[0, 2, 1, 2, 2, 1],
        [2, 0, 2, 2, 1, 2],
        [1, 2, 0, 2, 2, 1],
        [2, 2, 2, 0, 2, 2],
        [2, 1, 2, 2, 0, 2],
        [1, 2, 1, 2, 2, 0]])

```
          _________0_________     ------ level-2
         /         |         \
      __0__      __1__        2   ------ level-1
     /     \    /  |  \       |
    1       4  2   0   5      3   ------ level-0
 
   [2]     [2][0] [1] [1]    [2]  ------ lca(2, .) example
```

In [8]:
preds = outputs.argmax(dim=-1)

preds, targets, lca[preds, targets]

(tensor([2, 2, 0, 2]), tensor([2, 1, 5, 3]), tensor([0, 2, 1, 2]))

In [9]:
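# Example of lca metric: "hierarchical distance of a mistake" [Ber+20],
# i.e. the average LCA level over the misclassified samples.
misclassified = targets.numel() - (preds == targets).sum()
# Correct predictions have lca == 0, so summing over all samples
# is the same as summing over the mistakes only.
distance = lca[preds, targets].sum().float()
metric = distance / misclassified
metric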

tensor(1.6667)
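The computation above can be packaged as a reusable function. In the sketch below, `mistake_severity` is a hypothetical name; it selects the misclassified samples explicitly instead of relying on correct predictions contributing zero, and returns NaN when there are no mistakes, since the average is undefined in that case.

```python
import torch


def mistake_severity(preds, targets, lca):
    """Average LCA level over misclassified samples (hypothetical helper).

    preds, targets: (B,) integer tensors of leaf classes.
    lca: (C, C) matrix of LCA levels.
    """
    mistakes = preds != targets
    if not mistakes.any():
        return torch.tensor(float("nan"))  # no mistakes: average is undefined
    return lca[preds[mistakes], targets[mistakes]].float().mean()


# Values from the cells above
lca = torch.tensor([
    [0, 2, 1, 2, 2, 1],
    [2, 0, 2, 2, 1, 2],
    [1, 2, 0, 2, 2, 1],
    [2, 2, 2, 0, 2, 2],
    [2, 1, 2, 2, 0, 2],
    [1, 2, 1, 2, 2, 0],
])
preds = torch.tensor([2, 2, 0, 2])
targets = torch.tensor([2, 1, 5, 3])
print(mistake_severity(preds, targets, lca))  # tensor(1.6667)
```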