*Copyright (c) Microsoft Corporation. All rights reserved.*

*Licensed under the MIT License.*

# Text Classification of MultiNLI Sentences using MT-DNN

This notebook uses a PyTorch package that implements the Multi-Task Deep Neural Network (MT-DNN) toolkit for Natural Language Understanding.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os 
import sys
import torch

from tempfile import TemporaryDirectory
from utils_nlp.dataset.multinli import download_tsv_files_and_extract

from mtdnn.common.types import EncoderModelType
from mtdnn.configuration_mtdnn import MTDNNConfig
from mtdnn.modeling_mtdnn import MTDNNModel
from mtdnn.process_mtdnn import MTDNNDataProcess
from mtdnn.tasks.config import MTDNNTaskDefs
from mtdnn.data_builder_mtdnn import MTDNNDataBuilder
from mtdnn.tokenizer_mtdnn import MTDNNTokenizer



## Introduction
In this notebook, we fine-tune and evaluate an MT-DNN model on a subset of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset.

## Define Configuration, Tasks and Model Objects


In [3]:
# Define Configuration, Tasks and Model Objects
ROOT_DIR = TemporaryDirectory().name
OUTPUT_DIR = os.path.join(ROOT_DIR, 'checkpoint')
# exist_ok=True creates the folder only if it is missing
os.makedirs(OUTPUT_DIR, exist_ok=True)
DATA_DIR = os.path.join(ROOT_DIR, 'data')
os.makedirs(DATA_DIR, exist_ok=True)

BATCH_SIZE = 64
NUM_EPOCHS = 1
TRAIN_DATA_FRACTION = 0.05
TEST_DATA_FRACTION = 0.05
TRAIN_SIZE = 0.75

print(DATA_DIR)
print(OUTPUT_DIR)
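
The cell above keeps only the `.name` of a `TemporaryDirectory` object. Because nothing holds a reference to that object, Python may delete the directory as soon as the object is garbage-collected, which is why the folders are created again with `os.makedirs`. A minimal alternative sketch, assuming you want a scratch directory that persists until you remove it yourself, uses `tempfile.mkdtemp`. The lowercase variable names and the `mtdnn_` prefix are illustrative and not part of the notebook.

```python
import os
import shutil
import tempfile

# mkdtemp creates the directory and leaves cleanup to the caller.
root_dir = tempfile.mkdtemp(prefix="mtdnn_")
output_dir = os.path.join(root_dir, "checkpoint")
data_dir = os.path.join(root_dir, "data")

# exist_ok=True makes the call safe to re-run in a notebook.
for path in (output_dir, data_dir):
    os.makedirs(path, exist_ok=True)

created = os.path.isdir(output_dir) and os.path.isdir(data_dir)
print(created)

# Remove the scratch space once the experiment is finished.
shutil.rmtree(root_dir)
```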

## Read Dataset
We start by loading a subset of the data. The following function also downloads and extracts the files, if they don't exist in the data folder.

The MultiNLI dataset is mainly used for natural language inference (NLI) tasks, where the inputs are sentence pairs and the labels are entailment indicators. The sentence pairs are also classified into *genres* that allow for more coverage and better evaluation of NLI models.

For our classification task, we use the first sentence only as the text input, and the corresponding genre as the label. We select the examples corresponding to one of the entailment labels (*neutral* in this case) to avoid duplicate rows, as the sentences are not unique, whereas the sentence pairs are.
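
The selection described above can be sketched with the standard library alone. The rows below are toy data written for this illustration, and the column names (`gold_label`, `genre`, `sentence1`) are assumptions about the TSV layout rather than something this notebook shows.

```python
import csv
import io

# Toy rows invented for this sketch; each sentence appears once per entailment label.
raw = (
    "gold_label\tgenre\tsentence1\n"
    "neutral\ttravel\tThe museum opens at nine.\n"
    "entailment\ttravel\tThe museum opens at nine.\n"
    "neutral\tfiction\tShe closed the door quietly.\n"
    "contradiction\tfiction\tShe closed the door quietly.\n"
)

rows = csv.DictReader(io.StringIO(raw), delimiter="\t")

# Keep one entailment label only, so each first sentence is kept once,
# and use the genre as the classification label.
examples = [(r["sentence1"], r["genre"]) for r in rows if r["gold_label"] == "neutral"]

for text, genre in examples:
    print(genre, "|", text)
```

Filtering on a single entailment label is what removes the repeated first sentences in this toy data.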

In [4]:
download_tsv_files_and_extract(DATA_DIR)

100%|██████████| 305k/305k [00:07<00:00, 43.5kKB/s] 


Downloaded file to:  /tmp/tmp4rgav0t_/data/MNLI


### Define a Configuration Object 

Create a model configuration object, `MTDNNConfig`, with the parameters needed to initialize the MT-DNN model. Initializing it without any parameters yields a default configuration similar to the one used to initialize a BERT model.

In [5]:
config = MTDNNConfig(batch_size=BATCH_SIZE)


### Create Task Definition Object  

Define the parameters of the task(s) to train on in a task parameter dictionary, and use it to initialize an `MTDNNTaskDefs` object. The definition can cover a single task or multiple tasks. `MTDNNTaskDefs` accepts a Python dict, or a YAML or JSON file, containing the task definition(s).

The data source directory is the path to the data downloaded and extracted above with `download_tsv_files_and_extract`, that is, the `MNLI` directory under the temporary `DATA_DIR`.

The `data_process_opts` entry holds the options that drive the pre-processing of each task's data.


In [6]:
data_source_dir = os.path.join(DATA_DIR, "MNLI")
tasks_params = {
    "mnli": {
        "data_format": "PremiseAndOneHypothesis",
        "encoder_type": "BERT",
        "dropout_p": 0.3,
        "enable_san": True,
        "labels": ["contradiction", "neutral", "entailment"],
        "metric_meta": ["ACC"],
        "loss": "CeCriterion",
        "kd_loss": "MseCriterion",
        "n_class": 3,
        "split_names": [
            "train",
            "dev_matched",
            "dev_mismatched",
            "test_matched",
            "test_mismatched",
        ],
        "data_source_dir": data_source_dir,
        "data_process_opts": {"header": True, "is_train": True, "multi_snli": False,},
        "task_type": "Classification",
    },
}

# Define the tasks
task_defs = MTDNNTaskDefs(tasks_params)

06/19/2020 01:37:36 - mtdnn.tasks.config - INFO - Mapping Task attributes
06/19/2020 01:37:36 - mtdnn.tasks.config - INFO - Configured task definitions - ['mnli']
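
Since the task definitions may also come from a JSON file, the following minimal sketch writes a trimmed-down copy of the dictionary to JSON and reads it back using only the standard library. The file name `tasks.json` and the reduced set of fields are assumptions for illustration. Passing the resulting path to `MTDNNTaskDefs` is not shown, because this notebook does not demonstrate that call.

```python
import json
import os
import tempfile

# Trimmed-down task definition; a real one carries more fields.
tasks_params = {
    "mnli": {
        "data_format": "PremiseAndOneHypothesis",
        "labels": ["contradiction", "neutral", "entailment"],
        "n_class": 3,
        "task_type": "Classification",
    }
}

# Sanity check: the declared class count should match the label list.
for name, params in tasks_params.items():
    assert params["n_class"] == len(params["labels"]), name

# Write the definitions to a JSON file and load them back.
tasks_file = os.path.join(tempfile.mkdtemp(), "tasks.json")
with open(tasks_file, "w") as f:
    json.dump(tasks_params, f, indent=2)

with open(tasks_file) as f:
    reloaded = json.load(f)

print(reloaded == tasks_params)
```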



### Create the MTDNN Data Tokenizer Object  

Create a data tokenizing object, `MTDNNTokenizer`. Based on the model's initial checkpoint, it wraps the Hugging Face transformers tokenizer for that model and encodes the data into the MT-DNN format. The encoded data becomes the input to the data building stage.


In [7]:
tokenizer = MTDNNTokenizer()

#### Testing out the Tokenizer encode function on a sample text
`tokenizer.encode("He is a boy", "what is he")`

In [8]:
print(tokenizer.encode("He is a boy", "what is he"))

([101, 100, 2003, 1037, 2879, 102, 2054, 2003, 2002, 102], None, [0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
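
The tuple printed above appears to hold the token ids, a second element that is `None` for this call, and the segment (token type) ids. The sketch below shows how such a pair encoding is laid out, assuming the BERT convention `[CLS] first [SEP] second [SEP]` with ids 101 and 102 for `[CLS]` and `[SEP]`, as in the `bert-base-uncased` vocabulary. `build_pair_inputs` is a hypothetical helper and not part of `MTDNNTokenizer`.

```python
CLS_ID, SEP_ID = 101, 102  # [CLS] and [SEP] in the bert-base-uncased vocabulary

def build_pair_inputs(first_ids, second_ids):
    """Join two already-tokenized segments the way BERT expects a sentence pair."""
    input_ids = [CLS_ID] + first_ids + [SEP_ID] + second_ids + [SEP_ID]
    # Segment 0 covers [CLS], the first text and its [SEP];
    # segment 1 covers the second text and the final [SEP].
    type_ids = [0] * (len(first_ids) + 2) + [1] * (len(second_ids) + 1)
    return input_ids, type_ids

# Word-piece ids taken from the output above, without the special tokens.
input_ids, type_ids = build_pair_inputs([100, 2003, 1037, 2879], [2054, 2003, 2002])
print(input_ids)
print(type_ids)
```

The two printed lists match the first and third elements of the tuple returned by `tokenizer.encode`.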


### Create the Data Builder Object  

Create a data preprocessing object, `MTDNNDataBuilder`. This class converts the data of each task into the MT-DNN format.

The data builder uses the model tokenizer to create each task's vectorized data, which is needed to build the training, test and development PyTorch dataloaders.

In [9]:
## Load and build data
data_builder = MTDNNDataBuilder(
    tokenizer=tokenizer,
    task_defs=task_defs,
    data_dir="/home/useradmin/sources/mt-dnn-orig/data",
    canonical_data_suffix="canonical_data_2",
    dump_rows=True,
)

## Build data to MTDNN Format
## Iterable of each specific task and processed data
vectorized_data = data_builder.vectorize()

06/19/2020 01:37:39 - mtdnn.data_builder_mtdnn - INFO - Sucessfully loaded and built 392702 samples for mnli at /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/mnli_train.tsv
06/19/2020 01:37:39 - mtdnn.data_builder_mtdnn - INFO - Sucessfully loaded and built 9815 samples for mnli at /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/mnli_dev_matched.tsv
06/19/2020 01:37:39 - mtdnn.data_builder_mtdnn - INFO - Sucessfully loaded and built 9832 samples for mnli at /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/mnli_dev_mismatched.tsv
06/19/2020 01:37:39 - mtdnn.data_builder_mtdnn - INFO - Sucessfully loaded and built 9796 samples for mnli at /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/mnli_test_matched.tsv
06/19/2020 01:37:39 - mtdnn.data_builder_mtdnn - INFO - Sucessfully loaded and built 9847 samples for mnli at /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/mnli_test_mismatched.tsv
mnli_train
06/19/2020 01:37:39 - mtdnn.data_builder

Building Data For Premise and One Hypothesis: 392702it [03:57, 1654.72it/s]

06/19/2020 01:41:38 - mtdnn.data_builder_mtdnn - INFO - Saving data to /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/bert_base_uncased/mnli_train.json



Saving Data For PremiseAndOneHypothesis: 100%|██████████| 392702/392702 [00:05<00:00, 67824.04it/s]


mnli_dev_matched
06/19/2020 01:41:47 - mtdnn.data_builder_mtdnn - INFO - Building Data For 'MNLI DEV MATCHED' Task


Building Data For Premise and One Hypothesis: 9815it [00:06, 1608.14it/s]

06/19/2020 01:41:53 - mtdnn.data_builder_mtdnn - INFO - Saving data to /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/bert_base_uncased/mnli_dev_matched.json



Saving Data For PremiseAndOneHypothesis: 100%|██████████| 9815/9815 [00:00<00:00, 67040.72it/s]

mnli_dev_mismatched
06/19/2020 01:41:53 - mtdnn.data_builder_mtdnn - INFO - Building Data For 'MNLI DEV MISMATCHED' Task



Building Data For Premise and One Hypothesis: 9832it [00:06, 1625.22it/s]

06/19/2020 01:41:59 - mtdnn.data_builder_mtdnn - INFO - Saving data to /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/bert_base_uncased/mnli_dev_mismatched.json



Saving Data For PremiseAndOneHypothesis: 100%|██████████| 9832/9832 [00:00<00:00, 71501.71it/s]

mnli_test_matched
06/19/2020 01:41:59 - mtdnn.data_builder_mtdnn - INFO - Building Data For 'MNLI TEST MATCHED' Task



Building Data For Premise and One Hypothesis: 9796it [00:05, 1693.61it/s]

06/19/2020 01:42:05 - mtdnn.data_builder_mtdnn - INFO - Saving data to /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/bert_base_uncased/mnli_test_matched.json



Saving Data For PremiseAndOneHypothesis: 100%|██████████| 9796/9796 [00:00<00:00, 75186.93it/s]

mnli_test_mismatched
06/19/2020 01:42:05 - mtdnn.data_builder_mtdnn - INFO - Building Data For 'MNLI TEST MISMATCHED' Task



Building Data For Premise and One Hypothesis: 9847it [00:06, 1619.39it/s]

06/19/2020 01:42:11 - mtdnn.data_builder_mtdnn - INFO - Saving data to /home/useradmin/sources/mt-dnn-orig/data/canonical_data_2/bert_base_uncased/mnli_test_mismatched.json



Saving Data For PremiseAndOneHypothesis: 100%|██████████| 9847/9847 [00:00<00:00, 67734.49it/s]
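
To make the vectorizing step more concrete, here is a minimal sketch of turning one premise/hypothesis row into a JSON record. The record keys (`uid`, `label`, `token_id`, `type_id`), the toy vocabulary and the whitespace tokenizer are all assumptions made for illustration. The files written by `MTDNNDataBuilder` may be laid out differently.

```python
import json

LABELS = ["contradiction", "neutral", "entailment"]
LABEL_TO_ID = {name: idx for idx, name in enumerate(LABELS)}

# Toy vocabulary standing in for a real WordPiece vocabulary (assumption).
VOCAB = {"[CLS]": 0, "[SEP]": 1, "[UNK]": 2, "a": 3, "dog": 4,
         "runs": 5, "an": 6, "animal": 7, "moves": 8}

def to_ids(text):
    """Lowercase, split on whitespace, and map unknown words to [UNK]."""
    return [VOCAB.get(tok, VOCAB["[UNK]"]) for tok in text.lower().split()]

def vectorize(uid, label, premise, hypothesis):
    """Build one record: label index, token ids and segment ids."""
    p, h = to_ids(premise), to_ids(hypothesis)
    return {
        "uid": uid,
        "label": LABEL_TO_ID[label],
        "token_id": [VOCAB["[CLS]"]] + p + [VOCAB["[SEP]"]] + h + [VOCAB["[SEP]"]],
        "type_id": [0] * (len(p) + 2) + [1] * (len(h) + 1),
    }

record = vectorize("0", "entailment", "A dog runs", "An animal moves")
line = json.dumps(record)  # one record per line in a JSON-lines file
print(line)
```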


### Create the Data Processing Object  

Create a data processing object, `MTDNNDataProcess`. It creates the training, test and development PyTorch dataloaders needed for training and testing. It also provides the training options required to initialize the model correctly for all tasks.

In [10]:
# Make the Data Preprocess step and update the config with training data updates
data_processor = MTDNNDataProcess(
    config=config, task_defs=task_defs, vectorized_data=vectorized_data
)

06/19/2020 01:42:11 - mtdnn.process_mtdnn - INFO - Starting to process the training data sets
06/19/2020 01:42:11 - mtdnn.process_mtdnn - INFO - Loading mnli_train as task 0
06/19/2020 01:42:12 - mtdnn.dataset_mtdnn - INFO - Loaded 392702 samples out of 392702
06/19/2020 01:42:12 - mtdnn.process_mtdnn - INFO - Starting to process the testing data sets
06/19/2020 01:42:12 - mtdnn.process_mtdnn - INFO - Loading mnli_dev_matched as task 0
06/19/2020 01:42:12 - mtdnn.dataset_mtdnn - INFO - Loaded 9815 samples out of 9815
06/19/2020 01:42:12 - mtdnn.process_mtdnn - INFO - Loading mnli_dev_mismatched as task 0
06/19/2020 01:42:12 - mtdnn.dataset_mtdnn - INFO - Loaded 9832 samples out of 9832
06/19/2020 01:42:12 - mtdnn.process_mtdnn - INFO - Loading mnli_test_matched as task 0
06/19/2020 01:42:12 - mtdnn.dataset_mtdnn - INFO - Loaded 9796 samples out of 9796
06/19/2020 01:42:12 - mtdnn.process_mtdnn - INFO - Loading mnli_test_mismatched as task 0
06/19/2020 01:42:12 - mtdnn.dataset_mtdnn - I

Retrieve the processed multitask batch dataloaders for training, development and test.

In [11]:
multitask_train_dataloader = data_processor.get_train_dataloader()
dev_dataloaders_list = data_processor.get_dev_dataloaders()
test_dataloaders_list = data_processor.get_test_dataloaders()
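
A multitask training dataloader has to interleave batches from different tasks. One common scheme, sketched below in plain Python, splits each task's data into batches, tags every batch with its task index, and shuffles the combined list. This illustrates the general idea under those assumptions and is not the package's implementation. `multitask_schedule` and the toy data are hypothetical.

```python
import random

def batches(samples, batch_size):
    """Split a list of samples into consecutive batches."""
    return [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]

def multitask_schedule(task_data, batch_size, seed=0):
    """Return a shuffled list of (task_id, batch) pairs covering every task."""
    schedule = []
    for task_id, samples in enumerate(task_data):
        for batch in batches(samples, batch_size):
            schedule.append((task_id, batch))
    random.Random(seed).shuffle(schedule)
    return schedule

# Two toy tasks: 10 samples for task 0 and 4 samples for task 1.
task_data = [list(range(10)), list(range(100, 104))]
schedule = multitask_schedule(task_data, batch_size=4)

for task_id, batch in schedule:
    print(task_id, batch)
```

With a single task, as in this notebook, such a schedule reduces to that task's batches in shuffled order.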

Now we can retrieve the training options from the processor; they are used to initialize the model.

In [12]:
decoder_opts = data_processor.get_decoder_options_list()
task_types = data_processor.get_task_types_list()
dropout_list = data_processor.get_tasks_dropout_prob_list()
loss_types = data_processor.get_loss_types_list()
kd_loss_types = data_processor.get_kd_loss_types_list()
tasks_nclass_list = data_processor.get_task_nclass_list()
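
Each of these getters returns one entry per task. The sketch below shows how such index-aligned lists could be derived from a task parameter dictionary. The second task, `toy_binary`, is hypothetical and included only to show the alignment. This is not how `MTDNNDataProcess` is implemented.

```python
tasks_params = {
    "mnli": {
        "dropout_p": 0.3,
        "loss": "CeCriterion",
        "kd_loss": "MseCriterion",
        "n_class": 3,
        "task_type": "Classification",
    },
    # Hypothetical second task, present only to illustrate index alignment.
    "toy_binary": {
        "dropout_p": 0.1,
        "loss": "CeCriterion",
        "kd_loss": "MseCriterion",
        "n_class": 2,
        "task_type": "Classification",
    },
}

# Dict insertion order fixes the task index: mnli is task 0, toy_binary is task 1.
task_names = list(tasks_params)

def collect(field):
    """Gather one field across tasks, ordered by task index."""
    return [tasks_params[name][field] for name in task_names]

dropout_list = collect("dropout_p")
loss_types = collect("loss")
tasks_nclass_list = collect("n_class")

for idx, name in enumerate(task_names):
    print(idx, name, dropout_list[idx], loss_types[idx], tasks_nclass_list[idx])
```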

Next, retrieve the total number of batches, which is passed to the model as the number of training steps.

In [13]:
num_all_batches = data_processor.get_num_all_batches()
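
As a rough check on this number, the sketch below computes the step count for the sizes used in this notebook, assuming one optimizer step per batch and a final partial batch that is kept. Whether `get_num_all_batches` applies exactly this formula (for example, with respect to gradient accumulation) is an assumption. `total_train_steps` is a hypothetical helper.

```python
import math

def total_train_steps(num_samples, batch_size, epochs):
    """Batches per epoch (last partial batch included) times the number of epochs."""
    return math.ceil(num_samples / batch_size) * epochs

# 392702 training samples (from the log above), BATCH_SIZE = 64, NUM_EPOCHS = 1
steps = total_train_steps(392702, 64, 1)
print(steps)  # 6136
```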

### Instantiate the MTDNN Model

Now we can create an `MTDNNModel` object.

In [14]:
model = MTDNNModel(
    config,
    task_defs,
    pretrained_model_name="bert-base-uncased",
    num_train_step=num_all_batches,
    decoder_opts=decoder_opts,
    task_types=task_types,
    dropout_list=dropout_list,
    loss_types=loss_types,
    kd_loss_types=kd_loss_types,
    tasks_nclass_list=tasks_nclass_list,
    multitask_train_dataloader=multitask_train_dataloader,
    dev_dataloaders_list=dev_dataloaders_list,
    test_dataloaders_list=test_dataloaders_list,
)



idx: 0, number of task labels: 3


### Fit for One Epoch and Predict on the Training and Test Data

At this point, the MT-DNN model can be fitted and used to create predictions. The `fit` method takes an optional `epochs` parameter that overrides the number of epochs set in the `MTDNNConfig` object.
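
The optional `epochs` argument follows a common override pattern: use the configured value unless the caller passes one. The class below is a hypothetical stand-in written for this sketch and is not the `MTDNNModel` API. It only illustrates the override semantics.

```python
class SketchTrainer:
    """Hypothetical stand-in that mimics an optional epochs override."""

    def __init__(self, config_epochs):
        self.config_epochs = config_epochs  # default taken from the configuration
        self.epochs_run = 0

    def fit(self, epochs=None):
        # An explicit argument wins; otherwise fall back to the configured value.
        n = self.config_epochs if epochs is None else epochs
        for _ in range(n):
            self.epochs_run += 1  # one pass over the training data would go here
        return n

trainer = SketchTrainer(config_epochs=5)
default_epochs = trainer.fit()           # uses the configured 5 epochs
override_epochs = trainer.fit(epochs=1)  # overrides the configuration
print(default_epochs, override_epochs, trainer.epochs_run)
```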

In [23]:
multitask_train_dataloader

<mtdnn.dataset_mtdnn.MTDNNMultiTaskDataset at 0x7f425aa13a58>