# Multi-task Training with Hugging Face Transformers and NLP

### Or: A recipe for multi-task training with Transformers' Trainer and NLP datasets



Hugging Face has been building a lot of exciting new NLP functionality lately. The newly released [NLP](https://github.com/huggingface/nlp) library provides wide coverage of task datasets and metrics, as well as a simple interface for processing and caching the inputs extremely efficiently. They have also recently introduced a [Trainer](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class to the Transformers library that handles all of the training and validation logic.

However, one feature that is not supported in Hugging Face's current offerings is *multi-task training*. While there has been some discussion about the best way to support multi-task training ([1](https://github.com/huggingface/transformers/issues/4340), [2](https://github.com/huggingface/nlp/issues/217)), the community has not yet settled on a convention for doing so. Multi-task training has been shown to improve task performance ([1](https://www.aclweb.org/anthology/P19-1441/), [2](https://arxiv.org/abs/1910.10683)) and is a common experimental setting for NLP researchers.

In this Colab notebook, we will show how to use both the new NLP library and the Trainer for a **multi-task** training scheme.

So let's get started!

## Library setup

First up, we will install the *NLP* and *Transformers* libraries, along with *Datasets* (the renamed successor of NLP, which the data-loading code below uses). 

<font color='red'>**Note: After running the following cell, you will need to restart your runtime for the installation to work properly.**</font>

In [1]:
#!pip install git+https://github.com/huggingface/nlp
!pip install transformers==4.18
!pip install nlp==0.2.0
!pip install datasets



In [2]:
import numpy as np
import torch
import torch.nn as nn
import transformers
import nlp
import logging
import datasets
from datasets import Dataset
logging.basicConfig(level=logging.INFO)

## Fetching our data

To showcase our multi-task functionality, we will choose tasks of different formats:

* STS-B: A two-sentence textual similarity scoring task. (Prediction is a real number between 0 and 5)
* RTE: A two-sentence natural language entailment task. (Prediction is one of two classes)
* Commonsense QA: A multiple-choice question-answering task. (Each example consists of 5 separate text inputs; the prediction is which one of the 5 choices is correct)

In particular, notice that unlike STS-B and RTE, Commonsense QA consists of feeding *multiple* inputs into the transformer model. Many other tasks have weirder formats too, so our setup needs to be flexible enough to accommodate very different kinds of tasks.

Now, actually getting the task data is super simple. We can simply call the `datasets.load_dataset` method (the `datasets` library is the renamed successor of `nlp`), which automatically downloads the data and prepares it for use.

In [3]:
# !git init
# !git remote add origin https://github.com/martiansideofthemoon/style-transfer-paraphrase.git
# !git pull origin master
# !pip install -r requirements.txt
# !pip install --editable .
# !cd fairseq
# !pip install --editable .

In [4]:
# !unzip shakespeare.zip

In [5]:
import pandas as pd

dev = ['dev.input0.txt', 'dev.label', 'dev.paraphrase_250_input0.txt',]
train= ['train.input0.txt', 'train.label', 'train.paraphrase_250_input0.txt',]
test=['test.input0.txt', 'test.label', 'test.paraphrase_250_input0.txt',]

dev_dict = {}
train_dict = {}
test_dict = {}

for phase, phase_dict in [(dev,dev_dict), (train,train_dict), (test,test_dict)]:
  for name in phase:
    with open(f'shakespeare/{name}','r') as f:
      string = f.read()
      datalist = string.split("\n")
      if 'paraphrase' in name:
        phase_dict['translate'] = datalist  
      elif 'input0' in name:
        phase_dict['original'] = datalist
      else:
        phase_dict['labels'] = datalist

In [6]:
from sklearn import preprocessing
le = preprocessing.LabelEncoder()


#########
df = pd.DataFrame(train_dict).replace('', np.nan).dropna()
train_original_translate = df[df['labels'] == 'original'].drop(columns=['labels'])
train_modern_translate = df[df['labels'] == 'modern'].drop(columns=['labels'])

trans = df['translate'].to_frame()
df = df.drop(columns =['translate'])
trans = trans.rename(columns={'translate':'original'})
trans['labels'] = 'translate'
train_classify = pd.concat([df, trans])
train_classify = train_classify.rename(columns={'original':'text'})
le.fit(train_classify['labels'])
train_classify['labels'] = le.transform(train_classify['labels'])

#########
df = pd.DataFrame(test_dict).replace('', np.nan).dropna()
test_original_translate = df[df['labels'] == 'original'].drop(columns=['labels'])
test_modern_translate = df[df['labels'] == 'modern'].drop(columns=['labels'])

trans = df['translate'].to_frame()
df = df.drop(columns =['translate'])
trans = trans.rename(columns={'translate':'original'})
trans['labels'] = 'translate'
test_classify = pd.concat([df, trans])
test_classify = test_classify.rename(columns={'original':'text'})
# Reuse the encoder fitted on the training labels so label ids match across splits
test_classify['labels'] = le.transform(test_classify['labels'])

#########
df = pd.DataFrame(dev_dict).replace('', np.nan).dropna()
dev_original_translate = df[df['labels'] == 'original'].drop(columns=['labels'])
dev_modern_translate = df[df['labels'] == 'modern'].drop(columns=['labels'])

trans = df['translate'].to_frame()
df = df.drop(columns =['translate'])
trans = trans.rename(columns={'translate':'original'})
trans['labels'] = 'translate'
dev_classify = pd.concat([df, trans])
dev_classify = dev_classify.rename(columns={'original':'text'})
# Reuse the encoder fitted on the training labels so label ids match across splits
dev_classify['labels'] = le.transform(dev_classify['labels'])

#########
translate_original_dataset = {'train':Dataset.from_pandas(train_original_translate),
                              'test': Dataset.from_pandas(test_original_translate),
                              'dev': Dataset.from_pandas(dev_original_translate)}

translate_modern_dataset = {'train':Dataset.from_pandas(train_modern_translate),
                            'test': Dataset.from_pandas(test_modern_translate),
                            'dev': Dataset.from_pandas(dev_modern_translate)}

classify_dataset = {'train': Dataset.from_pandas(train_classify),
                    'test': Dataset.from_pandas(test_classify),
                    'dev': Dataset.from_pandas(dev_classify)}
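Cell In [6] reshapes each split into two kinds of data: per-style translation pairs (the `original` or `modern` rows together with their paraphrase) and a three-way style classification set, in which every text keeps its style label and every paraphrase is added again with the extra label `translate`. The pure-Python sketch below restates that reshaping on toy rows (`build_classify_rows` is a hypothetical helper and the sentences are made up). Assuming the label files contain only `original` and `modern`, it also shows why the classifier is later built with `num_labels=3`: `LabelEncoder` assigns ids in sorted class order, giving `modern` = 0, `original` = 1, `translate` = 2.

```python
from typing import Dict, List, Tuple

def build_classify_rows(
    rows: List[Tuple[str, str, str]]
) -> Tuple[List[Tuple[str, int]], Dict[str, int]]:
    """Turn (text, style_label, paraphrase) rows into (text, label_id) pairs.

    Mirrors the notebook: all texts with their style label first, then all
    paraphrases labelled 'translate'.
    """
    labelled = [(text, style) for text, style, _ in rows]
    labelled += [(para, "translate") for _, _, para in rows]
    # Like sklearn's LabelEncoder, assign ids in sorted order of class names
    classes = sorted({label for _, label in labelled})
    label_to_id = {label: i for i, label in enumerate(classes)}
    return [(text, label_to_id[label]) for text, label in labelled], label_to_id

rows = [("Thou art a villain.", "original", "You are a villain."),
        ("You are so mean.", "modern", "You are very unkind.")]
pairs, mapping = build_classify_rows(rows)
print(mapping)  # {'modern': 0, 'original': 1, 'translate': 2}
```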


In [7]:
shakespeare_original = {'translate':translate_original_dataset,
                        'classify':classify_dataset}

In [45]:
shakespeare_modern = {'translate':translate_modern_dataset,
                      'classify':classify_dataset}

In [8]:
dataset_dict = {
    "stsb": datasets.load_dataset('glue', name="stsb"),
    "rte": datasets.load_dataset('glue', name="rte"),
    "commonsense_qa": datasets.load_dataset('commonsense_qa'),
}

100%|██████████| 3/3 [00:00<00:00, 775.81it/s]
100%|██████████| 3/3 [00:00<00:00, 1054.29it/s]
100%|██████████| 3/3 [00:00<00:00, 529.43it/s]


We can show one example from each task.

In [9]:
for task_name, dataset in dataset_dict.items():
    print(task_name)
    print(dataset_dict[task_name]["train"][0])
    print()

stsb
{'sentence1': 'A plane is taking off.', 'sentence2': 'An air plane is taking off.', 'label': 5.0, 'idx': 0}

rte
{'sentence1': 'No Weapons of Mass Destruction Found in Iraq Yet.', 'sentence2': 'Weapons of Mass Destruction Found in Iraq.', 'label': 1, 'idx': 0}

commonsense_qa
{'answerKey': 'A', 'question': 'The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?', 'choices': {'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['ignore', 'enforce', 'authoritarian', 'yell at', 'avoid']}}



## Creating a Multi-task Model

Next up, we are going to create a multi-task model. 

Typically, a multi-task model in the age of BERT works by having a shared BERT-style encoder transformer, and different task heads for each task.

![Multi-Task 1](https://drive.google.com/uc?id=1TCdyyoHInbiZtSOUmyJN1miCj1iysygU)

We could try to implement this directly in code, but there are two downsides to this approach:

1. Hugging Face's Transformers has implementations for single-task models, but not modular task heads. This means we would need to do a lot of legwork to write our own task heads.
2. This format assumes that the input is processed the same way in the encoder for every task. Already, Commonsense QA is problematic for this approach, since it requires the encoder to process *multiple* input sequences for a single example. Other tasks may similarly break this abstraction.

Instead, we are going to do something **radically different**. We are going to create separate models for each task, but we are going to make them share the same encoder. 

![Multi-Task 2](https://drive.google.com/uc?id=1xmghPPO5RC-TnpYP4_PpZ-TRfJF33S6p)

This will serve the same goal as having the encoder be jointly trained across multiple tasks, but still retain the independent implementations of each model. As such, we can use the existing task-model implementations in Transformers, such as `RobertaForSequenceClassification` and `RobertaForMultipleChoice`.

Importantly, the shared encoder ensures that during training, all updates will update the same encoder weights, and the encoder itself **does not consume any additional GPU memory** per task; only the task-specific heads add parameters.
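The sharing trick relies on nothing more than Python object references: `MultitaskModel.create` takes the encoder attribute from the first task model and assigns that same object to every other task model. The stand-in classes below (`ToyEncoder`, `ToyTaskModel` and `share_encoder` are illustrative names, not Transformers classes) sketch the consequence: an update made through one task model is visible through the other, while each keeps its own head. The second model's original encoder is no longer referenced and can be garbage-collected, which is where the memory saving comes from.

```python
class ToyEncoder:
    """Stand-in for a transformer body: just holds a list of weights."""
    def __init__(self):
        self.weights = [0.0, 0.0]

class ToyTaskModel:
    """Stand-in for a single-task model: an encoder plus a task head."""
    def __init__(self, head_size):
        self.transformer = ToyEncoder()  # each model starts with its own copy
        self.head = [0.0] * head_size

def share_encoder(models, attr_name="transformer"):
    """Point every model's encoder attribute at the first model's encoder,
    mirroring the getattr/setattr loop in MultitaskModel.create."""
    shared = getattr(models[0], attr_name)
    for model in models[1:]:
        setattr(model, attr_name, shared)
    return shared

translate, classify = ToyTaskModel(4), ToyTaskModel(3)
share_encoder([translate, classify])
translate.transformer.weights[0] += 0.5   # an "update" made through one task...
print(classify.transformer.weights[0])    # ...is seen by the other: 0.5
```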

First, we define our `MultitaskModel` class:

In [10]:
class MultitaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, taskmodels_dict):
        """
        Setting MultitaskModel up as a PretrainedModel allows us
        to take better advantage of Trainer features
        """
        super().__init__(transformers.PretrainedConfig())

        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """
        This creates a MultitaskModel using the model class and config objects
        from single-task models. 

        We do this by creating each single-task model, and having them share
        the same encoder transformer.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name, 
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                shared_encoder = getattr(model, cls.get_encoder_attr_name(model))
            else:
                setattr(model, cls.get_encoder_attr_name(model), shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """
        The encoder transformer is named differently in each model "architecture".
        This method lets us get the name of the encoder attribute
        """
        model_class_name = model.__class__.__name__
        if model_class_name.startswith("Bert"):
            return "bert"
        elif model_class_name.startswith("Roberta"):
            return "roberta"
        elif model_class_name.startswith("Albert"):
            return "albert"
        elif model_class_name.startswith("GPT2"):
            return "transformer"
        else:
            raise KeyError(f"Add support for new model {model_class_name}")

    def forward(self, task_name, **kwargs):
        return self.taskmodels_dict[task_name](**kwargs)

As described above, the `MultitaskModel` class consists of only two components: the shared encoder and a dictionary of the individual task models. Now, we can simply create the corresponding task models by supplying the individual model classes and model configs. We will use Transformers' AutoModels to further automate the choice of model class given a model architecture (the commented-out code shows the original `roberta-base` setup; here we use `gpt2-large` with a causal-LM `translate` task and a sequence-classification `classify` task).

In [11]:
model = transformers.AutoConfig.from_pretrained("gpt2-large")

In [12]:
del model

In [13]:
# model_name = "roberta-base"
# multitask_model = MultitaskModel.create(
#     model_name=model_name,
#     model_type_dict={
#         "stsb": transformers.AutoModelForSequenceClassification,
#         "rte": transformers.AutoModelForSequenceClassification,
#         "commonsense_qa": transformers.AutoModelForMultipleChoice,
#     },
#     model_config_dict={
#         "stsb": transformers.AutoConfig.from_pretrained(model_name, num_labels=1),
#         "rte": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
#         "commonsense_qa": transformers.AutoConfig.from_pretrained(model_name),
#     },
# )

special_tokens = ['<bos>','<eos>','<pad>','<cls>']

tokenizer_args = {
    "bos_token": '<bos>',
    'eos_token': '<eos>',
    'pad_token': '<pad>',
    'cls_token': '<cls>',
}

model_name = "gpt2-large"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name,**tokenizer_args)
model_config_args = {
    "bos_token_id": tokenizer.bos_token_id,
    'eos_token_id': tokenizer.eos_token_id,
    'pad_token_id': tokenizer.pad_token_id,
    'cls_token_id': tokenizer.cls_token_id,
}

multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        "translate": transformers.AutoModelForCausalLM,
        "classify": transformers.AutoModelForSequenceClassification
    },
    model_config_dict={
        "translate": transformers.AutoConfig.from_pretrained(model_name, **model_config_args),
        "classify": transformers.AutoConfig.from_pretrained(model_name, num_labels=3,**model_config_args),
    },
)

for name, model in multitask_model.taskmodels_dict.items():
  model.resize_token_embeddings(len(tokenizer))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2-large and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


To confirm that all task models use the same encoder, we can check the data pointers of the respective encoders. In this case, we check that one of the shared encoder's weights (the MLP output projection of GPT-2 block 28) points to the same memory location in each model.

In [14]:
# if model_name.startswith("roberta-"):
#     print(multitask_model.encoder.embeddings.word_embeddings.weight.data_ptr())
#     print(multitask_model.taskmodels_dict["stsb"].roberta.embeddings.word_embeddings.weight.data_ptr())
#     print(multitask_model.taskmodels_dict["rte"].roberta.embeddings.word_embeddings.weight.data_ptr())
#     print(multitask_model.taskmodels_dict["commonsense_qa"].roberta.embeddings.word_embeddings.weight.data_ptr())
# else:
#     print("Exercise for the reader: add a check for other model architectures =)")


if model_name.startswith("gpt2-"):
    print(multitask_model.encoder.state_dict()['h.28.mlp.c_proj.weight'].data_ptr())
    print(multitask_model.taskmodels_dict["translate"].state_dict()['transformer.h.28.mlp.c_proj.weight'].data_ptr())
    print(multitask_model.taskmodels_dict["classify"].state_dict()['transformer.h.28.mlp.c_proj.weight'].data_ptr())
else:
    print("Exercise for the reader: add a check for other model architectures =)")

140401452068928
140401452068928
140401452068928
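For the reader's exercise above, one architecture-agnostic approach is to skip data pointers and test object identity directly. Transformers models expose a `base_model_prefix` class attribute naming their encoder attribute (`"bert"`, `"roberta"`, `"albert"`, and `"transformer"` for GPT-2), so the sketch below could also stand in for the hard-coded branches of `get_encoder_attr_name`. The helper names are ours, and the sketch assumes every task model follows that convention.

```python
from typing import Any, Mapping

def encoder_attr_name(model: Any) -> str:
    """Name of the attribute holding the encoder, read from base_model_prefix."""
    prefix = getattr(model, "base_model_prefix", "")
    if not prefix or not hasattr(model, prefix):
        raise KeyError(f"Add support for new model {model.__class__.__name__}")
    return prefix

def encoders_are_shared(encoder: Any, taskmodels: Mapping[str, Any]) -> bool:
    """True if every task model holds the very same encoder object."""
    return all(
        getattr(model, encoder_attr_name(model)) is encoder
        for model in taskmodels.values()
    )

# With the objects built earlier in this notebook (True when the encoder is shared):
# encoders_are_shared(multitask_model.encoder, multitask_model.taskmodels_dict)
```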


## Processing our task data

We have created a dictionary of NLP datasets above, but we need to do a little more work to convert the respective task data into model inputs.

We already loaded the tokenizer corresponding to our model above (with the extra `<bos>`, `<eos>`, `<pad>` and `<cls>` special tokens).

Next, we'll write some short functions to convert from raw text to tokenized text inputs. 

* Both STS-B and RTE are two-sentence input tasks, so we will concatenate the two sentences with the corresponding special tokens. (The tokenizer's `batch_encode_plus` method handles this for us.) So, the input might look like: 

```
['<s>', 'This', 'is', 'my', 'premise', '.', '</s>', '</s>', 'This', 'is', 'my', 'hypothesis', '.', '</s>']
```

* CommonsenseQA is a multiple-choice task. A single example consists of a question and five possible answer choices. We will feed the model inputs concatenated like `QUESTION + CHOICE_1`, `QUESTION + CHOICE_2`, and so on. 
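The two input layouts described above can be sketched on already-split word tokens. This is a simplification: the real tokenizer works on subword ids and inserts the special tokens itself. `build_pair_tokens`, `expand_multiple_choice` and `answer_key_to_index` are illustrative names; the letter-to-index mapping follows the `labels2id` dictionary in the commented-out CommonsenseQA converter.

```python
from typing import List, Tuple

def build_pair_tokens(first: List[str], second: List[str],
                      bos: str = "<s>", sep: str = "</s>") -> List[str]:
    """RoBERTa-style sentence-pair layout: <s> A </s> </s> B </s>."""
    return [bos] + first + [sep, sep] + second + [sep]

def expand_multiple_choice(question: str, choices: List[str]) -> List[Tuple[str, str]]:
    """One (question, choice) pair per answer option; each pair is encoded
    separately and the model scores which pair is correct."""
    return [(question, choice) for choice in choices]

def answer_key_to_index(answer_key: str, letters: str = "ABCDE") -> int:
    """Map an answer letter such as 'A' to the index of the correct choice."""
    return letters.index(answer_key)

print(build_pair_tokens("This is my premise .".split(), "This is my hypothesis .".split()))
question = ("The sanctions against the school were a punishing blow, and they "
            "seemed to what the efforts the school had made to change?")
pairs = expand_multiple_choice(question, ["ignore", "enforce", "authoritarian", "yell at", "avoid"])
print(len(pairs), answer_key_to_index("A"))  # 5 0
```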

In [33]:
max_length = 128

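def convert_to_translate_features(example_batch):
  # Mapped with batched=False: one example at a time, so the fields are plain strings.
  # Causal-LM layout: <paraphrase> <bos> <target text> <eos> <pad> ...
  inputs = example_batch['translate'] + '<bos>' + example_batch['original'] + '<eos>'
  features = tokenizer(inputs, padding='max_length', max_length=max_length, truncation=True)
  # features = tokenizer(inputs)
  labels = features['input_ids'].copy()
  # Note: GPT-2 embeds token_type_ids with its word-embedding matrix, so a
  # constant 1 adds the embedding of token id 1 at every position.
  token_type_ids = [1] * len(features['input_ids'])
  reach_bos = False
  reach_eos = False
  # Train only on the target text and its <eos>: the paraphrase prompt, the <bos>
  # separator and the padding after <eos> are set to -100 (ignored by the loss).
  for idx, tok in enumerate(labels):
    if tok == tokenizer.bos_token_id:
      reach_bos = True
      labels[idx] = -100
    if tok == tokenizer.eos_token_id:
      reach_eos = True
    if (not reach_bos) or (reach_bos and reach_eos and tok != tokenizer.eos_token_id):
      labels[idx] = -100
  features['labels'] = labels
  features['token_type_ids'] = token_type_ids
  return features

def convert_to_classify_features(example_batch):
  # Mapped with batched=True: each field is a list of values.
  # GPT2ForSequenceClassification classifies from the last non-padding token,
  # so appending <cls> gives it a dedicated position to read from.
  inputs = [sen + '<cls>' for sen in example_batch['text']]
  features = tokenizer(inputs, padding='max_length', max_length=max_length, truncation=True)
  # features = tokenizer(inputs)
  features['labels'] = example_batch['labels']
  return features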

translate_mapping_args = {'batched':False}
classify_mapping_args = {'batched':True}

convert_func_dict = {
    "translate": (convert_to_translate_features, translate_mapping_args),
    "classify": (convert_to_classify_features, classify_mapping_args)
}

# def convert_to_stsb_features(example_batch):
#     inputs = list(zip(example_batch['sentence1'], example_batch['sentence2']))
#     features = tokenizer.batch_encode_plus(
#         inputs, max_length=max_length, pad_to_max_length=True
#     )
#     features["labels"] = example_batch["label"]
#     return features

# def convert_to_rte_features(example_batch):
#     inputs = list(zip(example_batch['sentence1'], example_batch['sentence2']))
#     features = tokenizer.batch_encode_plus(
#         inputs, max_length=max_length, pad_to_max_length=True
#     )
#     features["labels"] = example_batch["label"]
#     return features

# def convert_to_commonsense_qa_features(example_batch):
#     num_examples = len(example_batch["question"])
#     num_choices = len(example_batch["choices"][0]["text"])
#     features = {}
#     for example_i in range(num_examples):
#         choices_inputs = tokenizer.batch_encode_plus(
#             list(zip(
#                 [example_batch["question"][example_i]] * num_choices,
#                 example_batch["choices"][example_i]["text"],
#             )),
#             max_length=max_length, pad_to_max_length=True,
#         )
#         for k, v in choices_inputs.items():
#             if k not in features:
#                 features[k] = []
#             features[k].append(v)
#     labels2id = {char: i for i, char in enumerate("ABCDE")}
#     # Dummy answers for test
#     if example_batch["answerKey"][0]:
#         features["labels"] = [labels2id[ans] for ans in example_batch["answerKey"]]
#     else:
#         features["labels"] = [0] * num_examples    
#     return features

# convert_func_dict = {
#     "stsb": convert_to_stsb_features,
#     "rte": convert_to_rte_features,
#     "commonsense_qa": convert_to_commonsense_qa_features,
# }
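The masking loop in `convert_to_translate_features` decides which positions of a `translate` example contribute to the loss. Restated on a plain list of token ids, the rule is: the paraphrase prompt, the `<bos>` separator and any padding after `<eos>` become `-100`, while the target text and its `<eos>` keep their ids. This is a sketch; `mask_translate_labels` is a hypothetical name, and it takes the `<bos>`/`<eos>` ids as arguments instead of reading them from the tokenizer.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_translate_labels(input_ids: List[int], bos_id: int, eos_id: int) -> List[int]:
    """Copy input_ids into labels, keeping only the target span."""
    labels = []
    seen_bos = False
    seen_eos = False
    for tok in input_ids:
        if tok == bos_id:
            seen_bos = True
            labels.append(IGNORE_INDEX)  # the separator itself is not a target
            continue
        if not seen_bos or seen_eos:
            labels.append(IGNORE_INDEX)  # prompt before <bos>, padding after <eos>
        else:
            labels.append(tok)           # target text, including its <eos>
        if tok == eos_id:
            seen_eos = True
    return labels

# Toy ids: 11 and 12 are the prompt, 21 and 22 the target; 50257/50258/50259
# are the <bos>/<eos>/<pad> ids seen in the tokenized example in this notebook.
example = [11, 12, 50257, 21, 22, 50258, 50259, 50259]
print(mask_translate_labels(example, bos_id=50257, eos_id=50258))
# [-100, -100, -100, 21, 22, 50258, -100, -100]
```

Keeping `labels` aligned with `input_ids` is fine here because `GPT2LMHeadModel` shifts the labels by one position internally before computing the loss.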

Now that we have defined the above functions, we can use the `dataset.map` method to apply the functions over our entire datasets. The library handles the mapping efficiently and caches the features. (Here `load_from_cache_file=False` forces the features to be recomputed rather than reloaded from an earlier cache.)
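The two conversion functions are mapped differently: `convert_to_translate_features` concatenates strings with `+`, which only works on a single example, so it runs with `batched=False`, while `convert_to_classify_features` builds its inputs with a list comprehension and runs with `batched=True`. The toy function below emulates only that calling convention; it is not how `Dataset.map` is implemented, and `toy_map` is a hypothetical name. Unbatched functions receive one example as a dict of values, batched functions receive a dict of column lists, and the returned columns are merged into the existing ones.

```python
from typing import Any, Callable, Dict, List

Row = Dict[str, Any]

def toy_map(rows: List[Row], func: Callable[[Row], Row],
            batched: bool, batch_size: int = 1000) -> List[Row]:
    """Toy emulation of the Dataset.map calling convention (not its implementation)."""
    if not batched:
        # func sees one example: {"column": value, ...}
        return [{**row, **func(row)} for row in rows]
    out = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # func sees a batch: {"column": [value, value, ...], ...}
        columns = {key: [row[key] for row in chunk] for key in chunk[0]}
        new_columns = func(columns)
        for i, row in enumerate(chunk):
            out.append({**row, **{key: values[i] for key, values in new_columns.items()}})
    return out

rows = [{"translate": "you are", "original": "thou art"}]
print(toy_map(rows, lambda ex: {"joined": ex["translate"] + "<bos>" + ex["original"]},
              batched=False))
# [{'translate': 'you are', 'original': 'thou art', 'joined': 'you are<bos>thou art'}]
texts = [{"text": "hello"}, {"text": "goodbye"}]
print(toy_map(texts, lambda batch: {"with_cls": [t + "<cls>" for t in batch["text"]]},
              batched=True))
# [{'text': 'hello', 'with_cls': 'hello<cls>'}, {'text': 'goodbye', 'with_cls': 'goodbye<cls>'}]
```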

In [34]:
# columns_dict = {
#     "stsb": ['input_ids', 'attention_mask', 'labels'],
#     "rte": ['input_ids', 'attention_mask', 'labels'],
#     "commonsense_qa": ['input_ids', 'attention_mask', 'labels'],
# }

columns_dict = {
    "translate": ['input_ids', 'attention_mask', 'labels', 'token_type_ids'],
    "classify": ['input_ids', 'attention_mask', 'labels'],
}


features_dict = {}
for task_name, dataset in shakespeare_original.items():
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        func, args = convert_func_dict[task_name]
        features_dict[task_name][phase] = phase_dataset.map(
            func,
            **args,
            load_from_cache_file=False,
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))
        features_dict[task_name][phase].set_format(
            type="torch", 
            columns=columns_dict[task_name],
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))

18395ex [00:06, 3046.98ex/s]


translate train 18395 18395
translate train 18395 18395


1462ex [00:00, 3164.78ex/s]


translate test 1462 1462
translate test 1462 1462


1218ex [00:00, 3154.89ex/s]


translate dev 1218 1218
translate dev 1218 1218


100%|██████████| 74/74 [00:02<00:00, 35.75ba/s]


classify train 73580 73580
classify train 73580 73580


100%|██████████| 6/6 [00:00<00:00, 41.46ba/s]


classify test 5848 5848
classify test 5848 5848


100%|██████████| 5/5 [00:00<00:00, 37.96ba/s]

classify dev 4872 4872
classify dev 4872 4872





In [46]:
columns_dict = {
    "translate": ['input_ids', 'attention_mask', 'labels', 'token_type_ids'],
    "classify": ['input_ids', 'attention_mask', 'labels'],
}


features_dict_modern = {}
for task_name, dataset in shakespeare_modern.items():
    features_dict_modern[task_name] = {}
    for phase, phase_dataset in dataset.items():
        func, args = convert_func_dict[task_name]
        features_dict_modern[task_name][phase] = phase_dataset.map(
            func,
            **args,
            load_from_cache_file=False,
        )
        print(task_name, phase, len(phase_dataset), len(features_dict_modern[task_name][phase]))
        features_dict_modern[task_name][phase].set_format(
            type="torch", 
            columns=columns_dict[task_name],
        )
        print(task_name, phase, len(phase_dataset), len(features_dict_modern[task_name][phase]))

18395ex [00:06, 3054.67ex/s]


translate train 18395 18395
translate train 18395 18395


1462ex [00:00, 3148.40ex/s]


translate test 1462 1462
translate test 1462 1462


1218ex [00:00, 3116.16ex/s]


translate dev 1218 1218
translate dev 1218 1218


100%|██████████| 74/74 [00:02<00:00, 34.58ba/s]


classify train 73580 73580
classify train 73580 73580


100%|██████████| 6/6 [00:00<00:00, 41.13ba/s]


classify test 5848 5848
classify test 5848 5848


100%|██████████| 5/5 [00:00<00:00, 41.18ba/s]

classify dev 4872 4872
classify dev 4872 4872





In [48]:
import pickle
with open('feature_dict_shakespeare_original.pkl', 'ab') as f:
    pickle.dump(features_dict, f)
with open('feature_dict_shakespeare_modern.pkl', 'ab') as f:
    pickle.dump(features_dict, f)

As a recap:

* We have created our multi-task model by fusing several single-task Transformer models
* We have created a (cached) dictionary of featurized inputs for each of our tasks, using the `datasets` library

Next up, we need to:

1. Set up our data loading
2. Set up our Trainer 
3. Start training!

## Preparing a multi-task data loader and Trainer

Setting up a multi-task data loader should be simple in principle - we simply need to sample from multiple single-task data loaders with some probability, and feed each batch to the multi-task model above. Of course, along with each batch, we also need to tell the model what task it is for, so `MultitaskModel` knows to use the right corresponding task-model.

However, because we want to use the built-in `Trainer` class in Transformers, this gets a little tricky, since the `Trainer` expects a single data loader, and expects a very specific format of per-batch data. This slice of code is somewhat of a hack around that constraint. (This can become a lot more streamlined with some tweaks to the Trainer code from the Hugging Face folks =))

We need to define a `MultitaskDataloader` that combines several data loaders into a single "data loader" - not so different from our multi-task model above! This `MultitaskDataloader` should do what we described: sample from different single-task data loaders, and yield a task batch and the corresponding task name (we're going to add the `task_name` to the batch data).

We will also need to override the `get_train_dataloader` method of the `Trainer` to play well with our `MultitaskDataloader`. We do this with a `MultitaskTrainer`.

In [17]:
import dataclasses
from torch.utils.data.dataloader import DataLoader
# from transformers.training_args import is_tpu_available
# from transformers.trainer import get_tpu_sampler
from transformers.data.data_collator import DataCollator, InputDataClass, DataCollatorMixin
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from typing import List, Union, Dict

class NLPDataCollator(DataCollatorMixin):
    """
    Extending the existing DataCollator to work with NLP dataset batches
    """

    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        first = features[0]
        if not isinstance(first, dict):
            # Otherwise, revert to the default collator shipped with Transformers
            return transformers.default_data_collator(features)

        # NLP datasets present features as a list of dictionaries
        # (one per example), so we adapt the collate_batch logic for that
        batch = {}
        if "labels" in first and first["labels"] is not None:
            # Integer labels (class ids, LM token ids) stay long; real-valued ones become float
            if first["labels"].dtype == torch.int64:
                batch["labels"] = torch.stack([f["labels"] for f in features]).long()
            else:
                batch["labels"] = torch.stack([f["labels"] for f in features]).float()
        for k, v in first.items():
            if k != "labels" and v is not None and not isinstance(v, str):
                f_list = []
                for f in features:
                    # Multiple-choice features arrive as a list of per-choice tensors
                    if isinstance(f[k], list) and len(f[k]) > 1:
                        f_list.append(torch.stack(f[k]))
                    else:
                        f_list.append(f[k])
                batch[k] = torch.stack(f_list)
        return batch


class StrIgnoreDevice(str):
    """
    This is a hack. The Trainer is going to call .to(device) on every input
    value, but we need to pass in an additional `task_name` string.
    This prevents it from throwing an error
    """
    def to(self, device):
        return self


class DataLoaderWithTaskname:
    """
    Wrapper around a DataLoader to also yield a task name
    """
    def __init__(self, task_name, data_loader):
        self.task_name = task_name
        self.data_loader = data_loader

        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        return len(self.data_loader)
    
    def __iter__(self):
        for batch in self.data_loader:
            batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield batch


class MultitaskDataloader:
    """
    Data loader that combines and samples from multiple single-task
    data loaders.
    """
    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict
        self.num_batches_dict = {
            task_name: len(dataloader) 
            for task_name, dataloader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)
        self.dataset = [None] * sum(
            len(dataloader.dataset) 
            for dataloader in self.dataloader_dict.values()
        )

    def __len__(self):
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        For each batch, sample a task, and yield a batch from the respective
        task Dataloader.

        We use size-proportional sampling, but you could easily modify this
        to sample from some-other distribution.
        """
        task_choice_list = []
        for i, task_name in enumerate(self.task_name_list):
            task_choice_list += [i] * self.num_batches_dict[task_name]
        task_choice_list = np.array(task_choice_list)
        np.random.shuffle(task_choice_list)
        dataloader_iter_dict = {
            task_name: iter(dataloader) 
            for task_name, dataloader in self.dataloader_dict.items()
        }
        for task_choice in task_choice_list:
            task_name = self.task_name_list[task_choice]
            yield next(dataloader_iter_dict[task_name])    

class MultitaskTrainer(transformers.Trainer):

    def get_single_train_dataloader(self, task_name, train_dataset):
        """
        Create a single-task data loader that also yields task names
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        # if is_tpu_available():
        #     train_sampler = get_tpu_sampler(train_dataset)
        # else:
        #     train_sampler = (
        #         RandomSampler(train_dataset)
        #         if self.args.local_rank == -1
        #         else DistributedSampler(train_dataset)
        #     )
        train_sampler = (
                RandomSampler(train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(train_dataset)
            )

        data_loader = DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=DataLoader(
              train_dataset,
              batch_size=self.args.train_batch_size,
              sampler=train_sampler,
              collate_fn=self.data_collator.collate_batch,
            ),
        )

        # if is_tpu_available():
        #     data_loader = pl.ParallelLoader(
        #         data_loader, [self.args.device]
        #     ).per_device_loader(self.args.device)
        return data_loader

    def get_train_dataloader(self):
        """
        Returns a MultitaskDataloader, which is not actually a Dataloader
        but an iterable that returns a generator that samples from each 
        task Dataloader
        """
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })
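`MultitaskDataloader.__iter__` builds an exhaustive, shuffled list of task indices, so every batch of every task is seen once per epoch and tasks are visited in proportion to their number of batches. The docstring notes that other distributions are possible. One common alternative, swapped in here and not used in this notebook, is temperature-scaled sampling, similar to the temperature-scaled mixing studied in the T5 paper linked in the introduction: task *i* is drawn with probability proportional to n_i^(1/T). The sketch below, with hypothetical helper names, computes those probabilities and draws a schedule with replacement using only the standard library.

```python
import random
from typing import Dict, List

def task_probabilities(num_batches: Dict[str, int], temperature: float = 1.0) -> Dict[str, float]:
    """p_i proportional to n_i ** (1 / T).

    T = 1 reproduces size-proportional sampling; larger T moves towards
    uniform sampling, which up-weights small tasks.
    """
    weights = {task: n ** (1.0 / temperature) for task, n in num_batches.items()}
    total = sum(weights.values())
    return {task: w / total for task, w in weights.items()}

def sample_task_schedule(num_batches: Dict[str, int], num_steps: int,
                         temperature: float = 1.0, seed: int = 0) -> List[str]:
    """Draw one task name per step, with replacement."""
    probs = task_probabilities(num_batches, temperature)
    rng = random.Random(seed)
    tasks = list(probs)
    return rng.choices(tasks, weights=[probs[t] for t in tasks], k=num_steps)

# Batch counts implied by this notebook at batch size 2
counts = {"translate": 9198, "classify": 36790}
print(task_probabilities(counts, temperature=1.0))  # roughly 0.2 / 0.8
print(task_probabilities(counts, temperature=2.0))  # roughly 1/3 / 2/3
```

Because draws are with replacement, a small task can be asked for more batches than its loader holds, so a sampling-based loader would need to re-create a task's iterator when it raises `StopIteration`.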

## Time to train!

Okay, we have done all the hard work, now it is time for it to pay off. We can now simply create our `MultitaskTrainer`, and start training! 

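(Training time depends on the GPU you are allocated. With `gpt2-large` and the settings below, the progress bar in the log estimates roughly 7 hours for the 4311 optimization steps of three epochs.)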
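Before launching a multi-hour run, it helps to check how many optimizer updates the settings imply. The sketch below assumes a single GPU (as on Colab) and one `DataLoader` per task without `drop_last`, and appears to mirror how `Trainer` derives the count in this Transformers version: the length of the combined loader (the sum of per-task batch counts) is floor-divided by `gradient_accumulation_steps` and multiplied by the number of epochs. `total_optimization_steps` is a hypothetical helper.

```python
import math
from typing import Iterable

def total_optimization_steps(dataset_sizes: Iterable[int], per_device_batch_size: int,
                             grad_accum_steps: int, num_epochs: float) -> int:
    """Optimizer steps implied by the settings, assuming a single device."""
    # Each task has its own DataLoader, so partial last batches are counted per task
    batches_per_epoch = sum(math.ceil(n / per_device_batch_size) for n in dataset_sizes)
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return math.ceil(num_epochs * updates_per_epoch)

# translate and classify training set sizes from the featurization output
print(total_optimization_steps([18395, 73580], per_device_batch_size=2,
                               grad_accum_steps=32, num_epochs=3))  # 4311
```

This matches the "Total optimization steps = 4311" line in the training log, and the effective batch size of 2 × 32 = 64 matches its "Total train batch size" line.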

In [35]:
train_dataset = {
    task_name: dataset["train"] 
    for task_name, dataset in features_dict.items()
}

In [42]:
features_dict['translate']['train']['input_ids'][1]

tensor([ 4360,   611,   345,   910,  3738,  1647,   338,  6776,    11,   314,
         1183,  1577,   345,   257,  3869,   290,   257, 43836, 14643,    11,
          393,  2460,   351, 24088,    11,   393,   407,    13, 50257, 11486,
          611, 14210,   910,  3738,  1647,  3160,    11,   318,   880,    11,
         1471,  2460,   351, 24088,    11,   393,   407, 25798,   284,   683,
           11,   314,  1183,   900, 17903,   287,   257, 14643,   286,  3869,
          290, 32405,  3998, 25286,  7278,  2402, 17903,    13, 50258, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259])

In [18]:
train_dataset = {
    task_name: dataset["train"] 
    for task_name, dataset in features_dict.items()
}
trainer = MultitaskTrainer(
    model=multitask_model,
    args=transformers.TrainingArguments(
        output_dir="./models/multitask_model",
        overwrite_output_dir=True,
        learning_rate=1e-5,
        do_train=True,
        num_train_epochs=3,
        # Adjust the batch size if this doesn't fit on the Colab GPU
        per_device_train_batch_size=2,
        # Effective batch size per optimizer step: 2 x 32 = 64 examples
        gradient_accumulation_steps=32,
        save_steps=3000,
        fp16=True,  # mixed-precision (AMP) training
        optim="adafactor",  # Adafactor instead of the default AdamW
    ),
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
)
trainer.train()

Using amp half precision backend
***** Running training *****
  Num examples = 91975
  Num Epochs = 3
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 32
  Total optimization steps = 4311
 12%|█▏        | 500/4311 [48:28<6:06:12,  5.77s/it]

{'loss': 0.3962, 'learning_rate': 8.842495940617027e-06, 'epoch': 0.35}


 23%|██▎       | 1000/4311 [1:37:09<5:19:42,  5.79s/it]

{'loss': 0.2514, 'learning_rate': 7.68267223382046e-06, 'epoch': 0.7}


 35%|███▍      | 1500/4311 [2:25:40<4:29:26,  5.75s/it]

{'loss': 0.2229, 'learning_rate': 6.522848527023893e-06, 'epoch': 1.04}


 46%|████▋     | 2000/4311 [3:14:09<3:45:51,  5.86s/it]

{'loss': 0.179, 'learning_rate': 5.363024820227326e-06, 'epoch': 1.39}


 58%|█████▊    | 2500/4311 [4:02:47<2:55:00,  5.80s/it]

{'loss': 0.1717, 'learning_rate': 4.2032011134307586e-06, 'epoch': 1.74}


 70%|██████▉   | 3000/4311 [4:50:58<2:08:39,  5.89s/it]Saving model checkpoint to ./models/multitask_model/checkpoint-3000
Configuration saved in ./models/multitask_model/checkpoint-3000/config.json


{'loss': 0.1625, 'learning_rate': 3.043377406634192e-06, 'epoch': 2.09}


Model weights saved in ./models/multitask_model/checkpoint-3000/pytorch_model.bin
 81%|████████  | 3500/4311 [5:39:41<1:18:27,  5.80s/it]

{'loss': 0.1279, 'learning_rate': 1.8858733472512181e-06, 'epoch': 2.44}


 93%|█████████▎| 4000/4311 [6:27:47<29:58,  5.78s/it]  

{'loss': 0.1311, 'learning_rate': 7.260496404546508e-07, 'epoch': 2.78}


100%|██████████| 4311/4311 [6:57:52<00:00,  5.89s/it]

Training completed. Do not forget to share your model on huggingface.co/models =)

{'train_runtime': 25072.8075, 'train_samples_per_second': 11.005, 'train_steps_per_second': 0.172, 'train_loss': 0.19967740208664064, 'epoch': 3.0}

TrainOutput(global_step=4311, training_loss=0.19967740208664064, metrics={'train_runtime': 25072.8075, 'train_samples_per_second': 11.005, 'train_steps_per_second': 0.172, 'train_loss': 0.19967740208664064, 'epoch': 3.0})
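As a sanity check, the main numbers in this log can be reproduced by hand. The sketch below assumes the Trainer's usual bookkeeping:

* the batch count per epoch is rounded up;
* optimizer updates per epoch are `batches // gradient_accumulation_steps`;
* throughput is averaged over all epochs.

It also treats the 91,975 examples as one pool. Each task's loader rounds up separately, so the real batch count can be a few batches higher. With only a handful of tasks, that does not change the result.

```python
import math

# Values copied from the TrainingArguments and the training log above.
num_examples = 91_975
per_device_batch_size = 2
grad_accum_steps = 32
num_epochs = 3
train_runtime_s = 25_072.8075

# Effective (total) batch size per optimizer step, on a single device.
total_batch_size = per_device_batch_size * grad_accum_steps

# Batches per epoch (the last partial batch still counts), then optimizer
# updates per epoch (leftover micro-batches are floored away).
batches_per_epoch = math.ceil(num_examples / per_device_batch_size)
updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
total_steps = math.ceil(num_epochs * updates_per_epoch)

# Throughput over the whole run.
samples_per_second = num_examples * num_epochs / train_runtime_s
steps_per_second = total_steps / train_runtime_s

print(total_batch_size, batches_per_epoch, updates_per_epoch, total_steps)
# → 64 45988 1437 4311
print(round(samples_per_second, 3), round(steps_per_second, 3))
# → 11.005 0.172
```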

All done! Now we can evaluate our multi-task model on each of the three tasks: RTE, STS-B and Commonsense QA. Because we are evaluating each task individually, we can simply use single-task data loaders.

We will call the Trainer's lower-level `prediction_loop` method directly, since it accepts a ready-made dataloader. (In older versions of Transformers it was the private `_prediction_loop`.)

In [None]:
preds_dict = {}
for task_name in ["rte", "stsb", "commonsense_qa"]:
    # Tag a plain single-task eval loader with its task name
    eval_dataloader = DataLoaderWithTaskname(
        task_name,
        trainer.get_eval_dataloader(eval_dataset=features_dict[task_name]["validation"])
    )
    preds_dict[task_name] = trainer.prediction_loop(
        eval_dataloader,
        description=f"Validation: {task_name}",
    )

Now that we have all the predictions, let's score them. The NLP library has built-in metrics for the GLUE tasks (which include STS-B and RTE), but not for Commonsense QA. Thankfully, Commonsense QA's evaluation metric is simple accuracy, which we can compute directly.

In [None]:
# Evaluate RTE
nlp.load_metric('glue', name="rte").compute(
    np.argmax(preds_dict["rte"].predictions, axis=1),
    preds_dict["rte"].label_ids,
)

In [None]:
# Evaluate STS-B
nlp.load_metric('glue', name="stsb").compute(
    preds_dict["stsb"].predictions.flatten(),
    preds_dict["stsb"].label_ids,
)

In [None]:
# Evaluate Commonsense QA
np.mean(
    np.argmax(preds_dict["commonsense_qa"].predictions, axis=1)
    == preds_dict["commonsense_qa"].label_ids
)

You should expect scores of approximately:

* RTE: ~0.74 (accuracy)
* STS-B: ~0.89 / ~0.89 (Pearson / Spearman correlation)
* Commonsense QA: ~0.60 (accuracy)
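For reference, here is a minimal, NumPy-only sketch of what these numbers measure. RTE and Commonsense QA use accuracy over argmax'd logits, as in the cells above. STS-B is scored with Pearson and Spearman correlation, which is why two numbers are listed. The helper names and toy values are hypothetical; in practice the library's GLUE metric computes the correlations for you.

```python
import numpy as np


def accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Classification / multiple-choice accuracy from raw logits."""
    return float(np.mean(np.argmax(logits, axis=1) == labels))


def average_ranks(x: np.ndarray) -> np.ndarray:
    """1-based ranks; tied values share the mean of their positions."""
    x = np.asarray(x, dtype=float)
    ranks = np.empty(len(x))
    ranks[np.argsort(x, kind="mergesort")] = np.arange(1, len(x) + 1)
    for value in np.unique(x):
        tied = x == value
        ranks[tied] = ranks[tied].mean()
    return ranks


def pearson(a: np.ndarray, b: np.ndarray) -> float:
    """Pearson correlation: linear agreement between predictions and labels."""
    return float(np.corrcoef(a, b)[0, 1])


def spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Spearman correlation: the Pearson correlation of the ranks."""
    return pearson(average_ranks(a), average_ranks(b))


# Toy values for illustration only (not real model outputs).
rte_logits = np.array([[2.0, -1.0], [0.1, 0.3], [1.5, 0.2], [-0.5, 0.5]])
rte_labels = np.array([0, 1, 1, 1])
print(accuracy(rte_logits, rte_labels))  # → 0.75

# Flattened regression-head outputs vs. gold similarity scores. Both lists
# have the same ordering, so Spearman is 1.0 up to float rounding.
stsb_preds = np.array([0.5, 2.0, 3.1, 4.8, 1.0])
stsb_labels = np.array([0.0, 2.2, 3.0, 5.0, 1.2])
print(pearson(stsb_preds, stsb_labels), spearman(stsb_preds, stsb_labels))
```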

These aren't award-winning scores, nor were our tasks chosen for multi-task synergy, but hopefully we have demonstrated how to do multi-task training with some of Hugging Face's latest offerings!

# An advertisement: Come check out jiant!

While the above recipe works, it also exposed several points of friction:

* handling multi-task data loading;
* coercing the Trainer to work with multi-task inputs;
* handling the featurization for each of the tasks.

If you are interested in more streamlined multi-task (or even single-task) fine-tuning, we are building [jiant](https://jiant.info/). It is a research-oriented NLP library, built directly on Transformers, in which multi-task training is a first-class feature. `jiant` aims to facilitate cutting-edge NLP transfer-learning research through broad task coverage and modular components, and we highly recommend it for streamlined multi-task training workflows.

(If you've previously worked with `jiant`, we are currently undertaking [a complete rewrite](https://github.com/jiant-dev/jiant) to better support current research needs and engineering workflows.)

Click [here](https://jiant.info/) to learn more, or attend our system demo presentation at [ACL 2020](https://acl2020.org/program/accepted/#system-demonstrations).