In [None]:
# Install the two packages imported in the next cell: `transformers` (unpinned)
# and `nlp`, pinned to version 0.2.0.
!pip install --quiet transformers
!pip install --quiet nlp==0.2.0 

In [None]:
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import transformers
import nlp
from sklearn import preprocessing
import matplotlib.pyplot as plt
import logging
from sklearn.model_selection import train_test_split
logging.basicConfig(level=logging.CRITICAL)

import json
from tqdm.notebook import tqdm

tqdm.pandas()

In [None]:
# Install and import `datasets`; `datasets.load_dataset` is used further down
# to build `dataset_dict`.
!pip install datasets
import datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Download the competition's public data archive to data.zip.
# NOTE(review): the URL carries signing query parameters
# (X-Amz-Date=20220721T014614Z, X-Amz-Expires=86400), so it is presumably
# time-limited -- replace it with a fresh link if the download fails.
!curl 'https://miniodis-rproxy.lisn.upsaclay.fr/py3-private/public_data/6a27238f-1500-4f52-a57a-0782addd41ec/competition/5792/1/data/public_data.zip?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=EASNOMJFX9QFW4QIY4SL%2F20220721%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220721T014614Z&X-Amz-Expires=86400&X-Amz-SignedHeaders=host&X-Amz-Signature=49179ae0578a4cbf92c5d30db2cba04e2c790ef1ef413f49b7a125f31335d011' --output data.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 9094k  100 9094k    0     0  6837k      0  0:00:01  0:00:01 --:--:-- 6837k


In [None]:
!unzip data.zip

Archive:  data.zip
replace train_dev/trac_2022_annotator_1_dev.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace train_dev/trac_2022_annotator_1_train.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: a
error:  invalid response [a]
replace train_dev/trac_2022_annotator_1_train.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace train_dev/trac_2022_annotator_2_dev.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace train_dev/trac_2022_annotator_2_train.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace train_dev/trac_2022_annotator_3_dev.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace train_dev/trac_2022_annotator_3_train.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace train_dev/trac_2022_annotator_4_dev.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace train_dev/trac_2022_annotator_4_train.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace train_dev/trac_2022_annotator_5_dev.tsv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n
replace train_dev/trac_2022_annotator_5_train.tsv? [

In [None]:
from os import listdir


def _load_trac_split(data_dir, split):
    """
    Read every annotator's TSV for one split into a single DataFrame.

    Args:
        data_dir: directory holding the `trac_2022_annotator_<k>_<split>.tsv` files.
        split: substring identifying the split in the file name ("train" or "dev").

    Returns:
        One DataFrame with a fresh 0..n-1 index: annotator 1's rows first, then
        the rows of every other file whose name contains `split`, taken in
        sorted file-name order.
    """
    first_file = f'trac_2022_annotator_1_{split}.tsv'
    # Exclude annotator 1 by its exact file name. A bare `'1' not in name`
    # test would also drop e.g. annotator 10-19. Sorting makes the row order
    # (and therefore the head/tail split below) independent of the order in
    # which the file system lists the directory.
    other_files = sorted(
        name for name in listdir(data_dir)
        if split in name and name != first_file
    )
    # Read everything first and concatenate once instead of re-copying the
    # growing frame on every loop iteration.
    frames = [
        pd.read_csv(f'{data_dir}/{name}', sep='\t')
        for name in [first_file] + other_files
    ]
    return pd.concat(frames).reset_index(drop=True)


df_train = _load_trac_split('/content/train_dev', 'train')
df_val = _load_trac_split('/content/train_dev', 'dev')

df_train.to_csv('trac_train.csv', index=False)
# NOTE(review): validation is the first 6000 dev rows and test is the last
# 6000; the two overlap whenever the merged dev frame has fewer than 12000 rows.
df_val.head(6000).to_csv('trac_val.csv', index=False)
df_val.tail(6000).to_csv('trac_test.csv', index = False)

In [None]:
class MultitaskModel(transformers.PreTrainedModel):
    """
    A set of single-task models that share one encoder.

    `forward` routes a batch to the task model registered under `task_name`.
    """

    # (class-name prefix, encoder attribute name) pairs, checked in order.
    # XLM-RoBERTa classes keep their encoder under `roberta`; the prefix is
    # spelled out in full so that other "XLM..." architectures are reported as
    # unsupported instead of being mapped to an attribute they do not have.
    _ENCODER_ATTR_BY_PREFIX = (
        ("Bert", "bert"),
        ("Roberta", "roberta"),
        ("Albert", "albert"),
        ("Deberta", "deberta"),
        ("XLMRoberta", "roberta"),
    )

    def __init__(self, encoder, taskmodels_dict):
        """
        Setting MultitaskModel up as a PretrainedModel allows us
        to take better advantage of Trainer features

        Args:
            encoder: the encoder module shared by all task models.
            taskmodels_dict: mapping of task name -> single-task model.
        """
        super().__init__(transformers.PretrainedConfig())

        self.encoder = encoder
        self.taskmodels_dict = nn.ModuleDict(taskmodels_dict)

    @classmethod
    def create(cls, model_name, model_type_dict, model_config_dict):
        """
        This creates a MultitaskModel using the model class and config objects
        from single-task models.

        We do this by creating each single-task model, and having them share
        the same encoder transformer.

        Args:
            model_name: checkpoint name/path passed to `from_pretrained`.
            model_type_dict: mapping of task name -> model class.
            model_config_dict: mapping of task name -> config for that task.

        Returns:
            A MultitaskModel whose task models all reference the encoder of
            the first model that was created.

        Raises:
            KeyError: if a model class is not covered by `get_encoder_attr_name`.
        """
        shared_encoder = None
        taskmodels_dict = {}
        for task_name, model_type in model_type_dict.items():
            model = model_type.from_pretrained(
                model_name,
                config=model_config_dict[task_name],
            )
            encoder_attr = cls.get_encoder_attr_name(model)
            if shared_encoder is None:
                # The first task model donates its encoder ...
                shared_encoder = getattr(model, encoder_attr)
            else:
                # ... and every later one has its own encoder replaced by it.
                setattr(model, encoder_attr, shared_encoder)
            taskmodels_dict[task_name] = model
        return cls(encoder=shared_encoder, taskmodels_dict=taskmodels_dict)

    @classmethod
    def get_encoder_attr_name(cls, model):
        """
        The encoder transformer is named differently in each model "architecture".
        This method lets us get the name of the encoder attribute

        Raises:
            KeyError: if the model's class name matches no known prefix.
        """
        model_class_name = model.__class__.__name__
        for prefix, attr_name in cls._ENCODER_ATTR_BY_PREFIX:
            if model_class_name.startswith(prefix):
                return attr_name
        raise KeyError(f"Add support for new model {model_class_name}")

    def forward(self, task_name, **kwargs):
        """Run the task model registered under `task_name` on `kwargs`."""
        return self.taskmodels_dict[task_name](**kwargs)

In [None]:
# Raw datasets keyed by task name: "pan" is loaded by dataset name with the
# 'bypublisher' configuration, "trac" from three local CSV files.
dataset_dict = dict(
    pan=datasets.load_dataset('hyperpartisan_news_detection', 'bypublisher'),
    trac=datasets.load_dataset(
        'csv',
        data_files=dict(
            train='/content/trac_train.csv',
            validation='/content/trac_val.csv',
            test='/content/trac_test.csv',
        ),
    ),
)



  0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Build the multitask model: both tasks use the XLM-R sequence-classification
# class on the same checkpoint and differ only in their config (5 labels for
# "pan"; 24 labels with the multi-label problem type for "trac").
model_name = "xlm-roberta-base"
multitask_model = MultitaskModel.create(
    model_name=model_name,
    model_type_dict=dict(
        pan=transformers.XLMRobertaForSequenceClassification,
        trac=transformers.XLMRobertaForSequenceClassification,
    ),
    model_config_dict=dict(
        pan=transformers.XLMRobertaConfig.from_pretrained(
            model_name,
            num_labels=5,
        ),
        trac=transformers.XLMRobertaConfig.from_pretrained(
            model_name,
            num_labels=24,
            problem_type="multi_label_classification",
        ),
    ),
)

# model_name = "microsoft/deberta-base"
# multitask_model = MultitaskModel.create(
#     model_name=model_name,
#     model_type_dict={
#         "pan": transformers.AutoModelForSequenceClassification,
#         "trac": transformers.AutoModelForSequenceClassification,
#     },
#     model_config_dict={
#         "pan": transformers.AutoConfig.from_pretrained(model_name, num_labels=5),
#         "trac": transformers.AutoConfig.from_pretrained(model_name, num_labels=24),
#     },
# )

https://huggingface.co/xlm-roberta-base/resolve/main/config.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpdjifigff


Downloading:   0%|          | 0.00/615 [00:00<?, ?B/s]

storing https://huggingface.co/xlm-roberta-base/resolve/main/config.json in cache at /root/.cache/huggingface/transformers/87683eb92ea383b0475fecf99970e950a03c9ff5e51648d6eee56fb754612465.dfaaaedc7c1c475302398f09706cbb21e23951b73c6e2b3162c1c8a99bb3b62a
creating metadata file for /root/.cache/huggingface/transformers/87683eb92ea383b0475fecf99970e950a03c9ff5e51648d6eee56fb754612465.dfaaaedc7c1c475302398f09706cbb21e23951b73c6e2b3162c1c8a99bb3b62a
loading configuration file https://huggingface.co/xlm-roberta-base/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/87683eb92ea383b0475fecf99970e950a03c9ff5e51648d6eee56fb754612465.dfaaaedc7c1c475302398f09706cbb21e23951b73c6e2b3162c1c8a99bb3b62a
Model config XLMRobertaConfig {
  "architectures": [
    "XLMRobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id

Downloading:   0%|          | 0.00/1.04G [00:00<?, ?B/s]

storing https://huggingface.co/xlm-roberta-base/resolve/main/pytorch_model.bin in cache at /root/.cache/huggingface/transformers/97d0ea09f8074264957d062ec20ccb79af7b917d091add8261b26874daf51b5d.f42212747c1c27fcebaa0a89e2a83c38c6d3d4340f21922f892b88d882146ac2
creating metadata file for /root/.cache/huggingface/transformers/97d0ea09f8074264957d062ec20ccb79af7b917d091add8261b26874daf51b5d.f42212747c1c27fcebaa0a89e2a83c38c6d3d4340f21922f892b88d882146ac2
loading weights file https://huggingface.co/xlm-roberta-base/resolve/main/pytorch_model.bin from cache at /root/.cache/huggingface/transformers/97d0ea09f8074264957d062ec20ccb79af7b917d091add8261b26874daf51b5d.f42212747c1c27fcebaa0a89e2a83c38c6d3d4340f21922f892b88d882146ac2
Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForSequenceClassification: ['lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_

roberta


Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForSequenceClassification: ['lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.bias', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.out_p

In [None]:
# Install `sentencepiece` before the tokenizer is created in the next cell
# (presumably required by that tokenizer -- confirm).
!pip install sentencepiece

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 31.2 MB/s 
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.1.96


In [None]:
# Load the tokenizer for `model_name` through the Auto class; the commented-out
# line is the explicit XLM-R tokenizer class, kept as an alternative.
# tokenizer = transformers.XLMRobertaTokenizer.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

Could not locate the tokenizer configuration file, will try to use the model config instead.
loading configuration file https://huggingface.co/xlm-roberta-base/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/87683eb92ea383b0475fecf99970e950a03c9ff5e51648d6eee56fb754612465.dfaaaedc7c1c475302398f09706cbb21e23951b73c6e2b3162c1c8a99bb3b62a
Model config XLMRobertaConfig {
  "_name_or_path": "xlm-roberta-base",
  "architectures": [
    "XLMRobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "xlm-roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_past": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.2

Downloading:   0%|          | 0.00/4.83M [00:00<?, ?B/s]

storing https://huggingface.co/xlm-roberta-base/resolve/main/sentencepiece.bpe.model in cache at /root/.cache/huggingface/transformers/9df9ae4442348b73950203b63d1b8ed2d18eba68921872aee0c3a9d05b9673c6.00628a9eeb8baf4080d44a0abe9fe8057893de20c7cb6e6423cddbf452f7d4d8
creating metadata file for /root/.cache/huggingface/transformers/9df9ae4442348b73950203b63d1b8ed2d18eba68921872aee0c3a9d05b9673c6.00628a9eeb8baf4080d44a0abe9fe8057893de20c7cb6e6423cddbf452f7d4d8
https://huggingface.co/xlm-roberta-base/resolve/main/tokenizer.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/transformers/tmpu7b4d_6z


Downloading:   0%|          | 0.00/8.68M [00:00<?, ?B/s]

storing https://huggingface.co/xlm-roberta-base/resolve/main/tokenizer.json in cache at /root/.cache/huggingface/transformers/daeda8d936162ca65fe6dd158ecce1d8cb56c17d89b78ab86be1558eaef1d76a.a984cf52fc87644bd4a2165f1e07e0ac880272c1e82d648b4674907056912bd7
creating metadata file for /root/.cache/huggingface/transformers/daeda8d936162ca65fe6dd158ecce1d8cb56c17d89b78ab86be1558eaef1d76a.a984cf52fc87644bd4a2165f1e07e0ac880272c1e82d648b4674907056912bd7
loading file https://huggingface.co/xlm-roberta-base/resolve/main/sentencepiece.bpe.model from cache at /root/.cache/huggingface/transformers/9df9ae4442348b73950203b63d1b8ed2d18eba68921872aee0c3a9d05b9673c6.00628a9eeb8baf4080d44a0abe9fe8057893de20c7cb6e6423cddbf452f7d4d8
loading file https://huggingface.co/xlm-roberta-base/resolve/main/tokenizer.json from cache at /root/.cache/huggingface/transformers/daeda8d936162ca65fe6dd158ecce1d8cb56c17d89b78ab86be1558eaef1d76a.a984cf52fc87644bd4a2165f1e07e0ac880272c1e82d648b4674907056912bd7
loading file h

In [None]:
# Token length every example is padded / truncated to.
max_length = 128


def convert_to_pan_features(example_batch):
    """
    Tokenize a batch of "pan" examples and attach their labels.

    Args:
        example_batch: batch with a "text" column and a "bias" column.

    Returns:
        The tokenizer output with an extra "labels" entry copied from "bias".
    """
    features = tokenizer.batch_encode_plus(
        example_batch["text"],
        max_length=max_length,
        padding="max_length",
        # Padding alone only lengthens short inputs; without explicit
        # truncation, texts longer than `max_length` keep their full length
        # and the examples can no longer be stacked into one batch tensor.
        truncation=True,
    )
    features["labels"] = example_batch["bias"]
    return features

def convert_to_trac_features(example_batch):
    """
    Tokenize a batch of "trac" examples and attach their annotation columns.

    Args:
        example_batch: batch with a "Text" column plus the seven annotation
            columns listed in `column_to_key` below.

    Returns:
        The tokenizer output with one extra entry per annotation column, stored
        under the short keys "ag", "ag_in", "dis", "gen", "comm", "caste", "race".
    """
    # NOTE(review): no "labels" entry is produced here, while the "trac" head
    # is configured for multi-label classification -- confirm how these seven
    # columns are meant to reach the model as targets.
    column_to_key = {
        "Aggression": "ag",
        "Aggression Intensity": "ag_in",
        "Discursive Role": "dis",
        "Gender Bias": "gen",
        "Communal Bias": "comm",
        "Caste/Class Bias": "caste",
        "Ethnicity/Racial Bias": "race",
    }
    features = tokenizer.batch_encode_plus(
        example_batch["Text"],
        max_length=max_length,
        padding="max_length",
        # Padding alone only lengthens short inputs; without explicit
        # truncation, texts longer than `max_length` keep their full length
        # and the examples can no longer be stacked into one batch tensor.
        truncation=True,
    )
    for column, key in column_to_key.items():
        features[key] = example_batch[column]
    return features

# Task name -> function that turns a raw example batch into model features.
convert_func_dict = dict(
    pan=convert_to_pan_features,
    trac=convert_to_trac_features,
)

In [None]:
# Columns of each task's feature dataset that are exposed (as torch tensors)
# by `set_format` below.
columns_dict = {
    "pan": ['input_ids', 'attention_mask', 'labels'],
    "trac": ['input_ids', 'attention_mask', 'ag', 'ag_in', 'dis', 'gen', 'comm', 'caste', 'race']
}

# features_dict[task][phase] holds the converted dataset for every task and
# every split ("phase") found in dataset_dict.
features_dict = {}
for task_name, dataset in dataset_dict.items():
    print(u"\u2192", task_name)
    features_dict[task_name] = {}
    for phase, phase_dataset in dataset.items():
        # Apply the task's converter batch-wise; `load_from_cache_file=False`
        # is passed so the mapping is presumably recomputed on every run.
        features_dict[task_name][phase] = phase_dataset.map(
            convert_func_dict[task_name],
            batched=True,
            load_from_cache_file=False,
        )
        # Row counts before/after mapping (printed again after set_format).
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))
        features_dict[task_name][phase].set_format(
            type="torch", 
            columns=columns_dict[task_name],
        )
        print(task_name, phase, len(phase_dataset), len(features_dict[task_name][phase]))

→ pan


  0%|          | 0/600 [00:00<?, ?ba/s]

pan train 600000 600000
pan train 600000 600000


  0%|          | 0/600 [00:00<?, ?ba/s]

pan validation 600000 600000
pan validation 600000 600000
→ trac


  0%|          | 0/129 [00:00<?, ?ba/s]

trac train 128575 128575
trac train 128575 128575


  0%|          | 0/6 [00:00<?, ?ba/s]

trac validation 6000 6000
trac validation 6000 6000


  0%|          | 0/6 [00:00<?, ?ba/s]

trac test 6000 6000
trac test 6000 6000


In [None]:
import dataclasses
from torch.utils.data.dataloader import DataLoader
from transformers.data.data_collator import DataCollator, InputDataClass, DefaultDataCollator
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from typing import List, Union, Dict


class NLPDataCollator(DefaultDataCollator):
    """
    Extending the existing DataCollator to work with NLP dataset batches
    """

    # Keys that carry targets rather than model inputs. They are stacked like
    # every other feature and then cast to the dtype a loss function expects.
    LABEL_KEYS = ("labels", "ag", "ag_in", "dis", "gen", "comm", "caste", "race")

    def collate_batch(self, features: List[Union[InputDataClass, Dict]]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of examples into one batch.

        Args:
            features: the examples of one batch. When they are dictionaries,
                every value that is neither None nor a string is expected to
                be a torch tensor of the same shape across examples.

        Returns:
            A dict mapping each kept key to the stacked tensor of that key
            over all examples. Keys in `LABEL_KEYS` are cast to `torch.long`
            when the stacked values are int64 and to `torch.float` otherwise.
            Non-dict examples are passed on to the parent collator.
        """
        first = features[0]
        if not isinstance(first, dict):
            # otherwise, revert to using the default collation. Depending on
            # the installed transformers release the parent exposes it either
            # as `collate_batch` or only through `__call__`.
            parent_collate = getattr(super(), "collate_batch", None)
            if parent_collate is not None:
                return parent_collate(features)
            return super().__call__(features)

        # NLP data sets present features as a list of dictionaries (one per
        # example), so the batch is built key by key. Starting from an empty
        # dict lets examples without a "labels" key collate too, and no key
        # collected earlier is discarded by a later one.
        batch = {}
        for key, value in first.items():
            if value is None:
                continue
            is_label = key in self.LABEL_KEYS
            if isinstance(value, str) and not is_label:
                # Free-text columns cannot be turned into tensors; skip them.
                continue
            stacked = torch.stack([f[key] for f in features])
            if is_label:
                target_dtype = torch.long if stacked.dtype == torch.int64 else torch.float
                stacked = stacked.to(target_dtype)
            batch[key] = stacked
        return batch


class StrIgnoreDevice(str):
    """
    A string that tolerates `.to(device)` calls.

    This is a hack: the Trainer calls `.to(device)` on every input value,
    and the extra `task_name` string travelling with each batch must not
    raise when that happens.
    """

    def to(self, device):
        """Ignore `device` and hand back this very string object."""
        return self


class DataLoaderWithTaskname:
    """
    Thin wrapper that tags every batch of a DataLoader with a task name.

    The wrapped loader's `batch_size` and `dataset` are mirrored as
    attributes of the wrapper.
    """

    def __init__(self, task_name, data_loader):
        self.task_name, self.data_loader = task_name, data_loader

        # Mirror the attributes of the wrapped loader.
        self.batch_size = self.data_loader.batch_size
        self.dataset = self.data_loader.dataset

    def __len__(self):
        """Number of batches of the wrapped loader."""
        return len(self.data_loader)

    def __iter__(self):
        """Yield each batch of the wrapped loader with a "task_name" entry added."""
        batches = iter(self.data_loader)
        for current in batches:
            current["task_name"] = StrIgnoreDevice(self.task_name)
            yield current


class MultitaskDataloader:
    """
    Iterable that interleaves batches from several single-task data loaders.
    """

    def __init__(self, dataloader_dict):
        self.dataloader_dict = dataloader_dict

        # Batches available per task.
        self.num_batches_dict = {}
        for name, loader in self.dataloader_dict.items():
            self.num_batches_dict[name] = len(loader)

        self.task_name_list = [*self.dataloader_dict]

        # Placeholder list with one slot per example over all tasks.
        total_examples = sum(
            len(loader.dataset) for loader in self.dataloader_dict.values()
        )
        self.dataset = [None] * total_examples

    def __len__(self):
        """Total number of batches over all tasks."""
        return sum(self.num_batches_dict.values())

    def __iter__(self):
        """
        Yield one batch at a time, each from a randomly picked task.

        The schedule holds every task index once per batch of that task and is
        shuffled with NumPy's global RNG, i.e. tasks are sampled in proportion
        to their number of batches. Swap in another schedule to sample from a
        different distribution.
        """
        schedule = np.array([
            task_idx
            for task_idx, name in enumerate(self.task_name_list)
            for _ in range(self.num_batches_dict[name])
        ])
        np.random.shuffle(schedule)

        iterators = {}
        for name, loader in self.dataloader_dict.items():
            iterators[name] = iter(loader)

        for task_idx in schedule:
            yield next(iterators[self.task_name_list[task_idx]])

class MultitaskTrainer(transformers.Trainer):
    """Trainer whose training data loader draws batches from several tasks."""

    def get_single_train_dataloader(self, task_name, train_dataset):
        """
        Build the data loader for one task; its batches also carry the task name.

        Raises:
            ValueError: if the trainer was created without a train_dataset.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        # Random sampling when local_rank is -1, distributed sampling otherwise.
        if self.args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)

        inner_loader = DataLoader(
            train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator.collate_batch,
        )
        return DataLoaderWithTaskname(task_name=task_name, data_loader=inner_loader)

    def get_train_dataloader(self):
        """
        Return a MultitaskDataloader over one loader per task.

        It is not an actual DataLoader but an iterable whose generator samples
        batches from the per-task loaders.
        """
        per_task_loaders = {}
        for task_name, task_dataset in self.train_dataset.items():
            per_task_loaders[task_name] = self.get_single_train_dataloader(
                task_name, task_dataset
            )
        return MultitaskDataloader(per_task_loaders)


In [None]:
# Task name -> the "train" split of that task's converted features.
train_dataset = {
    task_name: dataset["train"] 
    for task_name, dataset in features_dict.items()
}
# Train the shared-encoder model on all tasks with the multitask trainer,
# using the custom collator to build each batch.
trainer = MultitaskTrainer(
    model=multitask_model,
    args=transformers.TrainingArguments(
        output_dir="multitask_model",
        overwrite_output_dir=True,
        learning_rate=1e-5,
        do_train=True,
        num_train_epochs=2,
        # Adjust batch size if this doesn't fit on the Colab GPU
        per_device_train_batch_size=2,  
        save_steps=30000,
    ),
    data_collator=NLPDataCollator(),
    train_dataset=train_dataset,
) 
trainer.train()