# Multilabel classification

In **multilabel classification**, each instance can be assigned multiple labels simultaneously. This is different from multiclass classification, where each instance is assigned to one and only one class from a set of classes.

This notebook shows how to use torchTextClassifiers to perform multilabel classification.

## Ragged-lists approach

In [None]:
import numpy as np
import torch

from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
from torchTextClassifiers.dataset import TextClassificationDataset
from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
from torchTextClassifiers.model.components import (
    AttentionConfig,
    CategoricalVariableNet,
    ClassificationHead,
    TextEmbedder,
    TextEmbedderConfig,
)
from torchTextClassifiers.tokenizers import HuggingFaceTokenizer

%load_ext autoreload
%autoreload 2

Let's use fake data.

Look at `labels`: it is a list of lists, where each inner list contains the labels for the corresponding instance.

We're indeed in a multilabel classification setting, where each instance can have multiple labels.

In [None]:
sample_text_data = [
    "This is a positive example",
    "This is a negative example",
    "Another positive case",
    "Another negative case",
    "Good example here",
    "Bad example here",
]

labels = [[0, 1, 5], [0, 4], [1, 5], [0, 1, 4], [1, 5], [0]]

Note that `labels` is awkward to manipulate: the inner lists have different lengths, so you cannot convert it directly to a tensor or a NumPy array.

This is called a *jagged array* or *ragged array*.

However, you do not need to change anything: torchTextClassifiers can handle this kind of data directly.

In [None]:
np.array(labels)  # Does not work: recent NumPy versions raise a ValueError for ragged input

Let's import a pre-trained tokenizer from HuggingFace.

In [None]:
tokenizer = HuggingFaceTokenizer.load_from_pretrained(
    "google-bert/bert-base-uncased", output_dim=126
)

And create our input numpy array.

In [None]:
X = np.array(sample_text_data)

print(X.shape)

Y = labels  # alias for the sake of clarity; it remains a ragged list of lists here

We initialize a very simple model: no categorical features, no attention, just text input and multilabel output.

In this setting, we recommend using `torch.nn.BCEWithLogitsLoss()` as the loss function in the training config.

Each label is then treated as a separate binary classification problem, in which we estimate the probability that the label applies. These problems are not independent, because the model outputs the joint prediction vector. By contrast, the default setting (multiclass classification) uses `torch.nn.CrossEntropyLoss()`, which implies a *competition* among classes.
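To make the difference concrete, here is a minimal sketch that compares the two activations on one made-up logit vector. The logits and target are illustrative values, not output from torchTextClassifiers. The sigmoid scores each label on its own, while the softmax forces the scores to share a total mass of 1.

```python
import torch

# Hypothetical logits for one instance over 3 labels (illustrative values only).
logits = torch.tensor([[2.0, 2.0, -1.0]])
# Multi-hot target: labels 0 and 1 are both present.
target = torch.tensor([[1.0, 1.0, 0.0]])

# Sigmoid (applied internally by BCEWithLogitsLoss): one probability per label.
sigmoid_probs = torch.sigmoid(logits)
# Softmax (applied internally by CrossEntropyLoss): labels compete for a total mass of 1.
softmax_probs = torch.softmax(logits, dim=1)

print(sigmoid_probs)             # labels 0 and 1 can both be close to 1
print(softmax_probs.sum(dim=1))  # sums to 1 across labels

# BCEWithLogitsLoss is the mean of the per-label binary cross-entropies.
bce = torch.nn.BCEWithLogitsLoss()(logits, target)
manual = torch.nn.functional.binary_cross_entropy(sigmoid_probs, target)
print(bce, manual)
```

With the softmax, two labels with equal logits can never both exceed 0.5, which is why it is a poor fit when several labels are present at once.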

Note that we do not enforce this change of loss: if you do not specify it, the default loss (`CrossEntropyLoss`) will be used.

In [None]:
embedding_dim = 96
n_layers = 2
n_head = 4
n_kv_head = n_head
sequence_len = tokenizer.output_dim
num_classes = max(max(label_list) for label_list in labels) + 1

model_config = ModelConfig(
    embedding_dim=embedding_dim,
    num_classes=num_classes,
)

training_config = TrainingConfig(
    lr=1e-3,
    batch_size=4,
    num_epochs=1,
    loss=torch.nn.BCEWithLogitsLoss(),  # change the loss here
)

Do not forget to set `ragged_multilabel=True`!

In [None]:
ttc = torchTextClassifiers(
    tokenizer=tokenizer,
    model_config=model_config,
    ragged_multilabel=True,  # This is key !
)

Now you can train!

In [None]:
ttc.train(
    X_train=X,
    y_train=Y,
    training_config=training_config,
)

Under the hood, we efficiently convert your ragged lists of labels into a binary matrix in which each row corresponds to an instance and each column to a label. A value of 1 indicates that the label is present for the instance, and 0 that it is absent: **it is a one-hot version** of your ragged lists (strictly speaking a *multi-hot* encoding, since a row can contain several 1s).

You can have a look [here](../torchTextClassifiers/dataset/dataset.py#L85).
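The conversion can be sketched in a few lines of NumPy. This is an illustration under our own assumptions, not the library's actual implementation: `ragged_to_multi_hot` is a hypothetical helper name, and the linked source remains the reference.

```python
import numpy as np


def ragged_to_multi_hot(ragged_labels, num_classes):
    """Hypothetical helper: turn lists of label indices into a binary matrix."""
    multi_hot = np.zeros((len(ragged_labels), num_classes), dtype=np.float32)
    for row, label_list in enumerate(ragged_labels):
        # Set every listed label index of this instance to 1.
        multi_hot[row, label_list] = 1.0
    return multi_hot


# Same ragged labels as in the example above.
example_labels = [[0, 1, 5], [0, 4], [1, 5], [0, 1, 4], [1, 5], [0]]
example_num_classes = max(max(label_list) for label_list in example_labels) + 1

Y_dense = ragged_to_multi_hot(example_labels, example_num_classes)
print(Y_dense.shape)
print(Y_dense)
```

Each row of `Y_dense` has as many 1s as the corresponding inner list has labels, and the matrix has one column per class.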

## One-hot / multidimensional output approach

You can also choose to directly provide a one-hot / multidimensional array as labels.

For each sample, you have a vector of size equal to the number of labels, with 1s and 0s indicating the presence or absence of each label - or float values between 0 and 1, indicating the ground truth probability of each label.

You do not have ragged lists anymore: **set `ragged_multilabel=False`** in the `ttc` initialization. This is very important: otherwise, your labels will be interpreted as a bag of labels, as in the previous section. We throw a warning if we detect that your labels are one-hot encoded while `ragged_multilabel=True`, but we do not enforce anything.

Also, convert your labels to a NumPy array, which is now possible!

In [None]:
# We put 1s here, but it could be any float value (probabilities...)
labels = [
    [1., 1., 0., 0., 0., 1.],
    [1., 0., 0., 0., 1., 0.],
    [0., 1., 0., 0., 0., 1.],
    [1., 1., 0., 0., 1., 0.],
    [0., 1., 0., 0., 0., 1.],
    [1., 0., 0., 0., 1., 0.],
]
Y = np.array(labels)

In [None]:
ttc = torchTextClassifiers(
    tokenizer=tokenizer,
    model_config=model_config,
)  # We removed the ragged_multilabel flag here, it is False by default

In [None]:
ttc.train(
    X_train=X,
    y_train=Y,
    training_config=training_config,
)

As discussed, you can also put probabilities in `labels`.

In this case, choose the loss according to your task:

- `torch.nn.BCEWithLogitsLoss()` as loss function in the training config, if you are in a multilabel setting.
- `torch.nn.CrossEntropyLoss()` as loss function in the training config, if you are in a *soft* multiclass setting (i.e. each instance has only one label, but you provide probabilities instead of class indices). Normally, your ground truth probabilities should sum to 1 for each instance in this case.

We won't enforce anything that PyTorch does not enforce, so make sure to choose the right loss function for your task.
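As a minimal sketch of the two options, the snippet below feeds probability targets to both losses in plain PyTorch. The logits and targets are made-up values, and it assumes PyTorch 1.10 or later, where `CrossEntropyLoss` accepts class probabilities as targets.

```python
import torch

# Hypothetical model output for one instance over 3 classes (illustrative values).
logits = torch.tensor([[1.5, 0.2, -0.3]])

# Soft multiclass target: probabilities over classes, summing to 1.
soft_target = torch.tensor([[0.7, 0.2, 0.1]])
ce = torch.nn.CrossEntropyLoss()(logits, soft_target)

# The same quantity written out: -sum(p * log_softmax(logits)), averaged over the batch.
manual_ce = -(soft_target * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()

# Multilabel target with probabilities: one inclusion probability per label,
# and the entries do not need to sum to 1.
multilabel_target = torch.tensor([[0.9, 0.8, 0.1]])
bce = torch.nn.BCEWithLogitsLoss()(logits, multilabel_target)

print(ce, manual_ce, bce)
```

PyTorch does not check that the soft multiclass targets sum to 1, so that check is up to you.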