# MMTL BERT Tutorial

NOTE: This tutorial assumes that you have already completed the MMTL Basics tutorial.

In this tutorial, we demonstrate how the Snorkel MeTaL MMTL package can be used in a more advanced setting: more tasks, more complex modules (e.g., pretrained BERT), more complex data formatting (e.g., multiple fields per instance), and so on.

## Environment Setup

As usual, we'll first make sure that we can import from `metal` correctly.

```python
# Confirm we can import from metal
import sys
sys.path.append('../../metal')
import metal

# Import other dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F

# Set random seed for notebook
SEED = 123
```

```python
%load_ext autoreload
%autoreload 2
%matplotlib inline
```



## Load GLUE Tasks

We now load the data for 

### UNUSED VERBIAGE:

In this notebook, for simplicity and ease of visualization, we will solve two very simple geometric tasks (see the `MMTL_BERT` tutorial for an example with more complex tasks):
* Task 1: is 

Key Definitions:
* **`Payload`**: A `Payload` is a set of instances (data points) and one or more corresponding label sets. In standard single-task settings or vanilla multi-task settings, each dataset is represented by a separate `Payload` with a single label set: the ground truth labels. In other settings, however, the same set of instances can have labels for multiple tasks (e.g., sentences with labels for both sentiment classification and topic classification) and/or multiple labels for the same task (e.g., one to three labels from different crowdworkers).

* **`Task`**: A `Task` is a path through a network. In MeTaL, this corresponds to a particular sequence of PyTorch modules that each instance passes through, ending with a "task head" module that outputs a prediction for that instance on that task. `Task` objects are not restricted to working with only one `Payload` or label set.
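To make the two definitions concrete, here is a minimal, library-free sketch of the concepts. `SketchTask` and `SketchPayload` are hypothetical stand-ins written for illustration only; they are not MeTaL's `Task` or `Payload` classes, whose constructors and attributes may differ. A task is modeled as an ordered list of callables whose last element acts as the head, and a payload holds instances plus a dict of named label sets (here, one set of sentences labeled for both sentiment and topic).

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class SketchTask:
    """Hypothetical stand-in for a Task: a path of modules ending in a head."""
    name: str
    modules: List[Callable[[Any], Any]]  # shared modules first, task head last

    def predict(self, x: Any) -> Any:
        # Pass the instance through each module in order; the last one is the head.
        for module in self.modules:
            x = module(x)
        return x


@dataclass
class SketchPayload:
    """Hypothetical stand-in for a Payload: instances plus named label sets."""
    name: str
    instances: List[Any]
    label_sets: Dict[str, List[Any]] = field(default_factory=dict)


# Toy modules: a shared "body" and a sentiment-specific "head".
def body(text: str) -> set:
    return set(text.lower().split())  # bag of lowercased words

def sentiment_head(words: set) -> str:
    return "neg" if "awful" in words else "pos"

# One set of sentences labeled for two tasks (sentiment and topic).
payload = SketchPayload(
    name="sentences",
    instances=["Great phone", "Awful service"],
    label_sets={
        "sentiment": ["pos", "neg"],
        "topic": ["electronics", "service"],
    },
)

sentiment_task = SketchTask(name="sentiment", modules=[body, sentiment_head])
predictions = [sentiment_task.predict(x) for x in payload.instances]
print(predictions)  # ['pos', 'neg']
```

A second task for topic classification could reuse `body` and swap in its own head, which is the sense in which a `Task` is a path rather than a whole network.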


For illustration, consider two datasets: one containing tweets (the `Tweet` dataset) and the other containing product reviews (the `Reviews` dataset). In both cases, our goal is to predict the sentiment of the "document" (the tweet or review). However, our tweets have a label space of cardinality 2 (negative, positive), while our reviews have a label space of cardinality 3 (negative, neutral, positive). While there are certainly similarities between the two datasets (e.g., the word "awful" suggests a negative sentiment in both domains), there are also differences (e.g., average number of words, different vocabulary distributions). This makes the pair a good candidate for multi-task learning.
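The tweet/review setup can be sketched as one shared module feeding two task heads whose output sizes match each label space: 2 classes for `Tweet`, 3 for `Reviews`. This sketch is hypothetical and library-free: the cue-word "body" and the hand-set head weights are toy assumptions standing in for learned PyTorch modules.

```python
from typing import List, Sequence, Tuple

# Toy shared body: features = (negative cue count, positive cue count).
NEGATIVE_CUES = {"awful", "terrible", "bad"}
POSITIVE_CUES = {"great", "good", "love"}

def shared_body(doc: str) -> List[float]:
    words = doc.lower().split()
    return [float(sum(w in NEGATIVE_CUES for w in words)),
            float(sum(w in POSITIVE_CUES for w in words))]

# A head is (weight rows, biases); the number of rows equals the task's cardinality.
Head = Tuple[List[List[float]], List[float]]

TWEET_LABELS = ["negative", "positive"]
TWEET_HEAD: Head = ([[1.0, -1.0],     # negative
                     [-1.0, 1.0]],    # positive
                    [0.0, 0.0])

REVIEW_LABELS = ["negative", "neutral", "positive"]
REVIEW_HEAD: Head = ([[1.0, -1.0],    # negative
                      [-1.0, -1.0],   # neutral: favored when no cue words appear
                      [-1.0, 1.0]],   # positive
                     [0.0, 0.5, 0.0])

def predict(doc: str, head: Head, labels: Sequence[str]) -> str:
    weights, biases = head
    feats = shared_body(doc)  # the same body output feeds every head
    scores = [sum(w * f for w, f in zip(row, feats)) + b
              for row, b in zip(weights, biases)]
    best = max(range(len(scores)), key=scores.__getitem__)  # argmax
    return labels[best]

print(predict("awful traffic today", TWEET_HEAD, TWEET_LABELS))   # negative
print(predict("arrived on tuesday", REVIEW_HEAD, REVIEW_LABELS))  # neutral
print(predict("great screen", REVIEW_HEAD, REVIEW_LABELS))        # positive
```

Sharing `shared_body` is where the cross-domain similarity (e.g., "awful" being negative everywhere) is exploited, while the separate heads absorb the differing label spaces.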

We start by defining a `Payload`. 

A `Payload` is a bundle of instances (data points) and one or more corresponding label sets. These are wrapped together in a PyTorch `DataLoader` that returns batches of data. A `Payload` also specifies the split that the data belongs to (e.g., 'train' or 'test') and a dictionary mapping each `LabelSet` to its corresponding `Task`.
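A minimal sketch of this bundle, assuming plain Python lists in place of tensors and a hand-written batch generator in place of a PyTorch `DataLoader`. The `SimplePayload` class and its attribute names (`split`, `labels_to_tasks`, `batch_size`) are hypothetical illustrations, not MeTaL's constructor signature, and the 0/1 label encoding is an assumption.

```python
from dataclasses import dataclass
from typing import Dict, Iterator, List, Tuple

@dataclass
class SimplePayload:
    """Hypothetical payload: data, label sets, split, and a label-set-to-task map."""
    name: str
    split: str                          # e.g. "train" or "test"
    X: Dict[str, List]                  # field name -> values
    Y: Dict[str, List]                  # label set name -> labels
    labels_to_tasks: Dict[str, str]     # label set name -> task name
    batch_size: int = 2

    def batches(self) -> Iterator[Tuple[Dict[str, List], Dict[str, List]]]:
        # Stand-in for a DataLoader: yield aligned slices of every field and label set.
        n = len(next(iter(self.X.values())))
        for start in range(0, n, self.batch_size):
            stop = start + self.batch_size
            yield ({k: v[start:stop] for k, v in self.X.items()},
                   {k: v[start:stop] for k, v in self.Y.items()})

payload = SimplePayload(
    name="Tweet_train",
    split="train",
    X={"data": ["good day", "awful traffic", "love it"]},
    Y={"labels": [1, 0, 1]},            # assumed encoding: 0 = negative, 1 = positive
    labels_to_tasks={"labels": "TweetSentiment"},
)

for x_batch, y_batch in payload.batches():
    print(x_batch, y_batch)
```

Because every field and label set is sliced with the same indices, each batch keeps instances aligned with their labels, which is the property a real `DataLoader` over a dict-based dataset also has to preserve.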

If the two problems in an MTL setup have disjoint instance sets, we use one `Payload` per instance set. In our example, we have two related datasets with disjoint instances and only one label set (the gold labels) per instance set, so this is easily modeled using two `Payloads`.
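The two-`Payload` setup can be sketched with plain dicts, each mapping its single `labels` set to its own task. The dict layout, the payload and task names, and the label encodings are hypothetical choices for illustration; `route_by_task` is a hypothetical helper showing how label sets reach their tasks.

```python
from typing import Dict, List, Tuple

# Hypothetical payloads as plain dicts (not MeTaL's Payload class).
# Assumed label encodings: Tweet 0=negative, 1=positive;
# Reviews 0=negative, 1=neutral, 2=positive.
payloads = [
    {"name": "Tweet_train", "split": "train",
     "X": {"data": ["awful traffic", "love it"]},
     "Y": {"labels": [0, 1]},
     "labels_to_tasks": {"labels": "TweetSentiment"}},
    {"name": "Reviews_train", "split": "train",
     "X": {"data": ["awful battery", "it works", "great screen"]},
     "Y": {"labels": [0, 1, 2]},
     "labels_to_tasks": {"labels": "ReviewSentiment"}},
]

def route_by_task(payloads: List[dict]) -> Dict[str, List[Tuple[str, int]]]:
    """Collect (instance, label) pairs under the task each label set is mapped to."""
    by_task: Dict[str, List[Tuple[str, int]]] = {}
    for p in payloads:
        for label_set, task in p["labels_to_tasks"].items():
            pairs = zip(p["X"]["data"], p["Y"][label_set])
            by_task.setdefault(task, []).extend(pairs)
    return by_task

routed = route_by_task(payloads)
for task, pairs in routed.items():
    print(task, pairs)
```

If the same instances carried labels for both tasks instead, a single payload with two entries in its label-set-to-task mapping would be routed by the same loop.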

If your data has only one field (e.g., it is simply a Tensor), the default field name is `"data"`. Storing data as a dictionary has two advantages: (1) it allows additional fields to be added later with very few changes to the code, and (2) it supports models (such as BERT in NLP) that require multiple inputs per instance (e.g., token ids as well as segment ids).
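Both advantages can be seen in a small sketch. `as_field_dict` is a hypothetical helper, not part of MeTaL; it wraps single-field data under the default `"data"` key and leaves multi-field dicts untouched. The BERT-style ids are illustrative values, not the output of a real tokenizer.

```python
from typing import Any, Dict

DEFAULT_FIELD = "data"  # default field name, per the text above

def as_field_dict(X: Any) -> Dict[str, Any]:
    """Hypothetical helper: wrap single-field data under the default field name."""
    if isinstance(X, dict):
        return X
    return {DEFAULT_FIELD: X}

# Single-field data (a nested list standing in for a Tensor).
simple = as_field_dict([[0.1, 0.2], [0.3, 0.4]])

# BERT-style inputs with two fields per instance. The ids are illustrative,
# not the output of a real tokenizer.
bert_inputs = as_field_dict({
    "token_ids":   [[101, 2023, 102], [101, 7592, 102]],
    "segment_ids": [[0, 0, 0], [0, 0, 0]],
})

# Advantage (1): adding a field later only touches the dict.
bert_inputs["attention_mask"] = [[1, 1, 1], [1, 1, 1]]

print(sorted(simple))       # ['data']
print(sorted(bert_inputs))  # ['attention_mask', 'segment_ids', 'token_ids']
```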
 
Instances and label sets are stored together in a PyTorch `Dataset` object.
The `Dataset` is initialized with an X dict and a Y dict:
* The X dict is of the form {`field_name`: `values`}. The default field name is simply "data".
* The Y dict is of the form {`label_name`: `values`}. The default label set name is simply "labels".
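A minimal sketch of such a dataset, assuming list-valued fields and label sets. `DictDataset` is hypothetical and avoids a torch dependency, but it implements `__len__` and `__getitem__`, the same protocol a PyTorch map-style `Dataset` uses; the length check is an added safeguard, not something the text specifies.

```python
from typing import Any, Dict, List, Tuple

class DictDataset:
    """Hypothetical map-style dataset holding an X dict and a Y dict."""

    def __init__(self, X_dict: Dict[str, List[Any]], Y_dict: Dict[str, List[Any]]):
        # Every field and label set must describe the same number of instances.
        lengths = {len(v) for v in X_dict.values()} | {len(v) for v in Y_dict.values()}
        if len(lengths) != 1:
            raise ValueError(f"fields and label sets must share one length, got {lengths}")
        self.X_dict = X_dict
        self.Y_dict = Y_dict
        self._n = lengths.pop()

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, index: int) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # Return one instance as ({field_name: value}, {label_name: value}).
        x = {name: values[index] for name, values in self.X_dict.items()}
        y = {name: values[index] for name, values in self.Y_dict.items()}
        return x, y

ds = DictDataset(
    X_dict={"data": ["awful traffic", "love it"]},   # default field name
    Y_dict={"labels": [0, 1]},                       # default label set name
)
print(len(ds), ds[1])
```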

For example, metrics are reported with names of the form `task/payload/label_set/metric`.
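One way to build and take apart names in this format is sketched below. `metric_name` and `parse_metric_name` are hypothetical helpers, and the task, payload, and metric names are illustrative.

```python
from typing import Dict

def metric_name(task: str, payload: str, label_set: str, metric: str) -> str:
    """Build a key in the task/payload/label_set/metric format (hypothetical helper)."""
    return "/".join([task, payload, label_set, metric])

def parse_metric_name(name: str) -> Dict[str, str]:
    """Split a key back into its four components (hypothetical helper)."""
    parts = name.split("/")
    if len(parts) != 4:
        raise ValueError(f"expected 4 '/'-separated parts, got {len(parts)}: {name!r}")
    task, payload, label_set, metric = parts
    return {"task": task, "payload": payload, "label_set": label_set, "metric": metric}

name = metric_name("ReviewSentiment", "Reviews_valid", "labels", "accuracy")
print(name)  # ReviewSentiment/Reviews_valid/labels/accuracy
```

This scheme assumes that no task, payload, or label set name itself contains a `/`; otherwise the split becomes ambiguous.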
