Code adapted from [Huggingface](https://colab.research.google.com/github/zphang/zphang.github.io/blob/master/files/notebooks/Multi_task_Training_with_Transformers_NLP.ipynb)

In [None]:
!pip install transformers datasets

In [2]:
import numpy as np
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
import logging
logging.basicConfig(level=logging.INFO)

In [3]:
dataset_dict = {
    "entailment": load_dataset('glue', "rte"),
    "event_detection": load_dataset('csv', data_files="event_detection.csv"),
}



  0%|          | 0/3 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
dataset_dict

{'entailment': DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 2490
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 277
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 3000
    })
}),
 'event_detection': DatasetDict({
    train: Dataset({
        features: ['sentence', 'label'],
        num_rows: 1743
    })
})}

In [5]:
def combine_sentences(example):
  example['sentence'] = example['sentence1'] + ' ' + example['sentence2']
  return example

In [6]:
dataset_dict['entailment']['train'] = dataset_dict['entailment']['train'].map(lambda examples: combine_sentences(examples), remove_columns=['sentence1', 'sentence2', 'idx'])
dataset_dict['entailment']['validation'] = dataset_dict['entailment']['validation'].map(lambda examples: combine_sentences(examples), remove_columns=['sentence1', 'sentence2', 'idx'])
dataset_dict['entailment']['test'] = dataset_dict['entailment']['test'].map(lambda examples: combine_sentences(examples), remove_columns=['sentence1', 'sentence2', 'idx'])



  0%|          | 0/2490 [00:00<?, ?ex/s]

  0%|          | 0/277 [00:00<?, ?ex/s]

  0%|          | 0/3000 [00:00<?, ?ex/s]

In [7]:
dataset_dict

{'entailment': DatasetDict({
    train: Dataset({
        features: ['label', 'sentence'],
        num_rows: 2490
    })
    validation: Dataset({
        features: ['label', 'sentence'],
        num_rows: 277
    })
    test: Dataset({
        features: ['label', 'sentence'],
        num_rows: 3000
    })
}),
 'event_detection': DatasetDict({
    train: Dataset({
        features: ['sentence', 'label'],
        num_rows: 1743
    })
})}

In [8]:
class MultitaskModel(transformers.PreTrainedModel):
    def __init__(self, encoder, taskmodels_dict):
        """
        Setting MultitaskModel up as a PretrainedModel allows us
        to take better advantage of Trainer features
        """
        super().__init__(transformers.PretrainedConfig())

        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """
        This creates a MultitaskModel using the model class and config objects
        from single-task models. 

        We do this by creating each single-task model, and having them share
        the same encoder transformer.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name, 
                config=model_config_dict[task_name],
            )
            if shared_encoder is None:
                shared_encoder = getattr(model, cls.get_encoder_attr_name(model))
            else:
                setattr(model, cls.get_encoder_attr_name(model), shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """
        The encoder transformer is named differently in each model "architecture".
        This method lets us get the name of the encoder attribute
        """
        model_class_name = model.__class__.__name__
        if model_class_name.startswith("Bert"):
            return "bert"
        elif model_class_name.startswith("Roberta"):
            return "roberta"
        elif model_class_name.startswith("Albert"):
            return "albert"
        elif model_class_name.startswith("Electra"):
            return "electra"
        else:
            raise KeyError(f"Add support for new model {model_class_name}")

    def forward(self, task_name, **kwargs):
        return self.taskmodels_dict[task_name](**kwargs)

In [9]:
model_name = "google/electra-base-discriminator"
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict={
        "entailment": transformers.AutoModelForSequenceClassification,
        "event_detection": transformers.AutoModelForSequenceClassification,
    },
    model_config_dict={
        "entailment": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
        "event_detection": transformers.AutoConfig.from_pretrained(model_name, num_labels=2),
    },
)

Some weights of the model checkpoint at google/electra-base-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.d

In [10]:
if model_name.startswith("google/electra"):
    print(multitask_model.encoder.embeddings.word_embeddings.weight.data_ptr())
    print(multitask_model.taskmodels_dict["entailment"].electra.embeddings.word_embeddings.weight.data_ptr())
    print(multitask_model.taskmodels_dict["event_detection"].electra.embeddings.word_embeddings.weight.data_ptr())
else:
    print("Exercise for the reader: add a check for other model architectures =)")

560701440
560701440
560701440


In [11]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

In [12]:
max_length = 128

def convert_to_features(example_batch):
    inputs = list(example_batch['sentence'])
    features = tokenizer.batch_encode_plus(
        inputs, max_length=max_length, pad_to_max_length=True
    )
    features["labels"] = example_batch["label"]
    return features

convert_func_dict = {
    "entailment": convert_to_features,
    "event_detection": convert_to_features,
}

In [13]:
columns_dict = {
    "entailment": ['input_ids', 'attention_mask', 'labels'],
    "event_detection": ['input_ids', 'attention_mask', 'labels'],
}

features_dict = {}
for task_name, dataset in dataset_dict.items():
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        features_dict[task_name][phase] = phase_dataset.map(
            convert_func_dict[task_name],
            batched=True,
            load_from_cache_file=False,
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))
        features_dict[task_name][phase].set_format(
            type="torch", 
            columns=columns_dict[task_name],
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))

  0%|          | 0/3 [00:00<?, ?ba/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


entailment train 2490 2490
entailment train 2490 2490


  0%|          | 0/1 [00:00<?, ?ba/s]

entailment validation 277 277
entailment validation 277 277


  0%|          | 0/3 [00:00<?, ?ba/s]

entailment test 3000 3000
entailment test 3000 3000


  0%|          | 0/2 [00:00<?, ?ba/s]

event_detection train 1743 1743
event_detection train 1743 1743


In [14]:
import dataclasses
from torch.utils.data import DataLoader
from transformers.data.data_collator import DataCollator, InputDataClass
from transformers import DefaultDataCollator
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from typing import List, Union, Dict


class NLPDataCollator(DefaultDataCollator):
    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        first = features[0]
        if isinstance(first, dict):
          # NLP data sets current works presents features as lists of dictionary
          # (one per example), so we  will adapt the collate_batch logic for that
          if "labels" in first and first["labels"] is not None:
              if first["labels"].dtype == torch.int64:
                  labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
              else:
                  labels = torch.tensor([f["labels"] for f in features], dtype=torch.float)
              batch = {"labels": labels}
          for k, v in first.items():
              if k != "labels" and v is not None and not isinstance(v, str):
                  batch[k] = torch.stack([f[k] for f in features])
          return batch
        else:
          # otherwise, revert to using the default collate_batch
          return DefaultDataCollator().collate_batch(features)


class StrIgnoreDevice(str):
    """
    This is a hack. The Trainer is going call .to(device) on every input
    value, but we need to pass in an additional `task_name` string.
    This prevents it from throwing an error
    """
    def to(self, device):
        return self


class DataLoaderWithTaskname:
    """
    Wrapper around a DataLoader to also yield a task name
    """
    def __init__(self, task_name, data_loader):
        self.task_name = task_name
        self.data_loader = data_loader

        self.batch_size = data_loader.batch_size
        self.dataset = data_loader.dataset

    def __len__(self):
        return len(self.data_loader)
    
    def __iter__(self):
        for batch in self.data_loader:
            batch["task_name"] = StrIgnoreDevice(self.task_name)
            yield batch


class MultitaskDataloader:
    """
    Data loader that combines and samples from multiple single-task
    data loaders.
    """
    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict
        self.num_batches_dict = {
            task_name: len(dataloader) 
            for task_name, dataloader in self.dataloader_dict.items()
        }
        self.task_name_list = list(self.dataloader_dict)
        self.dataset = [None] * sum(
            len(dataloader.dataset) 
            for dataloader in self.dataloader_dict.values()
        )

    def __len__(self):
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        For each batch, sample a task, and yield a batch from the respective
        task Dataloader.

        We use size-proportional sampling, but you could easily modify this
        to sample from some-other distribution.
        """
        task_choice_list = []
        for i, task_name in enumerate(self.task_name_list):
            task_choice_list += [i] * self.num_batches_dict[task_name]
        task_choice_list = np.array(task_choice_list)
        np.random.shuffle(task_choice_list)
        dataloader_iter_dict = {
            task_name: iter(dataloader) 
            for task_name, dataloader in self.dataloader_dict.items()
        }
        for task_choice in task_choice_list:
            task_name = self.task_name_list[task_choice]
            yield next(dataloader_iter_dict[task_name])    

class MultitaskTrainer(transformers.Trainer):

    def get_single_train_dataloader(self, task_name, train_dataset):
        """
        Create a single-task data loader that also yields task names
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = (
            RandomSampler(train_dataset)
            if self.args.local_rank == -1
            else DistributedSampler(train_dataset)
        )

        data_loader = DataLoaderWithTaskname(
            task_name=task_name,
            data_loader=DataLoader(
              train_dataset,
              batch_size=self.args.train_batch_size,
              sampler=train_sampler,
              collate_fn=self.data_collator.collate_batch,
            ),
        )
        return data_loader

    def get_train_dataloader(self):
        """
        Returns a MultitaskDataloader, which is not actually a Dataloader
        but an iterable that returns a generator that samples from each 
        task Dataloader
        """
        return MultitaskDataloader({
            task_name: self.get_single_train_dataloader(task_name, task_dataset)
            for task_name, task_dataset in self.train_dataset.items()
        })

In [16]:
train_dataset = {
    task_name: dataset["train"] 
    for task_name, dataset in features_dict.items()
}
trainer = MultitaskTrainer(
    model=multitask_model,
    args=transformers.TrainingArguments(
        output_dir="./models/multitask_model",
        overwrite_output_dir=True,
        learning_rate=1e-5,
        do_train=True,
        num_train_epochs=5,
        # Adjust batch size if this doesn't fit on the Colab GPU
        per_device_train_batch_size=8,  
        save_steps=3000,
    ),
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
)
trainer.train()

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
***** Running training *****
  Num examples = 4233
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2650


Step,Training Loss
500,0.4504
1000,0.3512
1500,0.2556
2000,0.1732
2500,0.1229




Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=2650, training_loss=0.2604896358274064, metrics={'train_runtime': 235.0228, 'train_samples_per_second': 90.055, 'train_steps_per_second': 11.276, 'total_flos': 1401811279027200.0, 'train_loss': 0.2604896358274064, 'epoch': 5.0})

In [17]:
multitask_model.save_pretrained("model")

Configuration saved in model/config.json
Model weights saved in model/pytorch_model.bin
