In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
task_names = [
    "COLA", 
#     "SST2",
#     "MNLI",
#     "RTE",
#     "WNLI",
#     "QQP",
#     "MRPC",
#     "STSB",
#     "QNLI",
#     "SPACY_POS",
#     "SPACY_NER",
#     "THIRD",
#     "BLEU",
]

FILENAME = "COLA_tasks_and_payloads"

In [3]:
SEED = 1

## Load previously trained model
Hint: make sure the `bert_model` is initialized correctly!

In [4]:
%%time
from metal.mmtl.glue.glue_tasks import create_tasks_and_payloads

# Create tasks and payloads
tasks, payloads = create_tasks_and_payloads(
    task_names,
    dl_kwargs={"batch_size": 16},
    freeze_bert=True,
    bert_model='bert-large-cased'
)

Using random seed: 491344
Loading COLA Dataset


HBox(children=(IntProgress(value=0, max=8550), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1042), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1063), HTML(value='')))


CPU times: user 24.4 s, sys: 6.69 s, total: 31.1 s
Wall time: 28.1 s


In [5]:
tasks, payloads

([ClassificationTask(name=COLA, loss_multiplier=1.0)],
 [Payload(COLA_train: tasks=[COLA], split=train),
  Payload(COLA_valid: tasks=[COLA], split=valid),
  Payload(COLA_test: tasks=[COLA], split=test)])

In [6]:
from metal.mmtl.metal_model import MetalModel

model = MetalModel(tasks, seed=SEED, verbose=False)

In [7]:
import os
import torch
model_dir = '/dfs/scratch0/chami/metal/metal/mmtl/aws/output/2019_03_14_01_58_14/0/logdir/bert_large/QNLI.STSB.MRPC.QQP.RTE.MNLI.SST2.COLA.WNLI_09_15_09'
model_path = os.path.join(model_dir, 'best_model.pth')
device = torch.device(f"cuda:0")
model.load_state_dict(torch.load(model_path, map_location=device)["model"], strict=False)

#### Sanity check that task head is trained!

In [8]:
model.score(payloads[1])

{'COLA/COLA_valid/accuracy': 0.8464491362763915,
 'COLA/COLA_valid/matthews_corr': 0.6311800125409577}

## Define slices for evaluation

In [9]:
%%time
from metal.mmtl.glue.glue_tasks import create_tasks_and_payloads

# define slices
slice_dict = {  # A map of the slices that apply to each task
   "COLA": ["ends_with_question_mark"]
}

# Create tasks and payloads
_, payloads_slice = create_tasks_and_payloads(
    task_names,
    dl_kwargs={"batch_size": 16},
    slice_dict=slice_dict,
    freeze_bert=True,
    bert_model='bert-large-cased'
)

Using random seed: 128037
Loading COLA Dataset


HBox(children=(IntProgress(value=0, max=8550), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1042), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1063), HTML(value='')))


Added label_set with 8550 labels for task COLA:ends_with_question_mark to payload COLA_train.
Added label_set with 1042 labels for task COLA:ends_with_question_mark to payload COLA_valid.
Added label_set with 1063 labels for task COLA:ends_with_question_mark to payload COLA_test.
CPU times: user 18.2 s, sys: 2.33 s, total: 20.5 s
Wall time: 21.9 s


In [10]:
payloads_slice

[Payload(COLA_train: tasks=[COLA,COLA:ends_with_question_mark], split=train),
 Payload(COLA_valid: tasks=[COLA,COLA:ends_with_question_mark], split=valid),
 Payload(COLA_test: tasks=[COLA,COLA:ends_with_question_mark], split=test)]

#### Sanity check the number of examples in the train set

In [11]:
import numpy as np
def count_num_labels(labels):
    return np.sum(np.array(labels) != 0)

In [12]:
dataset = payloads_slice[0].data_loader.dataset
for labelset_name, labels in dataset.labels.items():
    print(labelset_name, count_num_labels(labels))

COLA 8550
COLA:ends_with_question_mark 615


### Evaluate baseline model on the slice of interest

In [13]:
model.score(payloads_slice[1])

Evaluating 57 / 1042 active labels


{'COLA/COLA_valid/accuracy': 0.8464491362763915,
 'COLA/COLA_valid/matthews_corr': 0.6311800125409577,
 'COLA:ends_with_question_mark/COLA_valid/accuracy': 0.7543859649122807,
 'COLA:ends_with_question_mark/COLA_valid/matthews_corr': 0.4818181818181818}

## Finetune model on slice of interest

In [14]:
from metal.mmtl.metal_model import MetalModel

from metal.mmtl.trainer import MultitaskTrainer
trainer = MultitaskTrainer(seed=SEED)

In [15]:
# only finetune on the slices, not the original task
for p in payloads_slice:
    p.task_names.remove('COLA')
payloads_slice

[Payload(COLA_train: tasks=[COLA:ends_with_question_mark], split=train),
 Payload(COLA_valid: tasks=[COLA:ends_with_question_mark], split=valid),
 Payload(COLA_test: tasks=[COLA:ends_with_question_mark], split=test)]

NOTE: We are training on a different set of payloads than we initialized the model with.

In [16]:
trainer.train_model(
    model,
    payloads_slice,
    checkpoint_metric="COLA/COLA_valid/matthews_corr",
    checkpoint_metric_mode="max",
    writer="tensorboard",
    optimizer="adamax",
    lr=5e-5,
    l2=1e-2,
    log_every=0.05, 
    score_every=0.1,
    n_epochs=3,
    progress_bar=True
)

Adding missing slice heads to train {'COLA:ends_with_question_mark'}
Beginning train loop.
Expecting a total of approximately 8560 examples and 535 batches per epoch from 1 payload(s) in the train split.
Writing config to /dfs/scratch0/vschen/metal-mmtl/logs/2019_03_17/01_59_39/config.json


HBox(children=(IntProgress(value=0, max=535), HTML(value='')))



[0.05 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.12e-01] model:[train/loss=3.12e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[0.10 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=1.61e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.50e-01] model:[train/loss=1.61e-01, train/lr=5.00e-05, valid/glue=nan]


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


[0.15 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.78e-01] model:[train/loss=2.78e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[0.20 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=7.20e-02, COLA_valid/accuracy=7.19e-01, COLA_valid/matthews_corr=3.87e-01] model:[train/loss=7.20e-02, train/lr=5.00e-05, valid/glue=nan]




[0.25 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.81e-01] model:[train/loss=3.81e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[0.30 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.29e-01, COLA_valid/accuracy=7.19e-01, COLA_valid/matthews_corr=3.87e-01] model:[train/loss=2.29e-01, train/lr=5.00e-05, valid/glue=nan]




[0.35 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.93e-01] model:[train/loss=3.93e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[0.40 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=1.72e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=1.72e-01, train/lr=5.00e-05, valid/glue=nan]




[0.45 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.75e-01] model:[train/loss=3.75e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[0.50 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=1.68e-01, COLA_valid/accuracy=7.19e-01, COLA_valid/matthews_corr=3.87e-01] model:[train/loss=1.68e-01, train/lr=5.00e-05, valid/glue=nan]




[0.56 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.40e-01] model:[train/loss=3.40e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[0.61 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.20e-01, COLA_valid/accuracy=7.19e-01, COLA_valid/matthews_corr=3.87e-01] model:[train/loss=2.20e-01, train/lr=5.00e-05, valid/glue=nan]




[0.66 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.82e-01] model:[train/loss=3.82e-01, train/lr=5.00e-05]



HBox(children=(IntProgress(value=0, max=535), HTML(value='')))

Evaluating 57 / 1042 active labels
[0.71 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=1.43e-01, COLA_valid/accuracy=7.19e-01, COLA_valid/matthews_corr=3.87e-01] model:[train/loss=1.43e-01, train/lr=5.00e-05, valid/glue=nan]




[0.76 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.38e-01] model:[train/loss=2.38e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[0.81 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.00e-01, COLA_valid/accuracy=7.19e-01, COLA_valid/matthews_corr=3.87e-01] model:[train/loss=3.00e-01, train/lr=5.00e-05, valid/glue=nan]




[0.86 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.20e-01] model:[train/loss=2.20e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[0.91 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.17e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=3.17e-01, train/lr=5.00e-05, valid/glue=nan]




[0.96 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=4.56e-01] model:[train/loss=4.56e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[1.01 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=1.31e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=1.31e-01, train/lr=5.00e-05, valid/glue=nan]




[1.06 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.92e-01] model:[train/loss=3.92e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[1.11 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.25e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=3.25e-01, train/lr=5.00e-05, valid/glue=nan]




[1.16 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.33e-01] model:[train/loss=3.33e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[1.21 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.24e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=2.24e-01, train/lr=5.00e-05, valid/glue=nan]




[1.26 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.49e-01] model:[train/loss=3.49e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[1.31 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.29e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=2.29e-01, train/lr=5.00e-05, valid/glue=nan]




[1.36 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.17e-01] model:[train/loss=2.17e-01, train/lr=5.00e-05]



HBox(children=(IntProgress(value=0, max=535), HTML(value='')))

Evaluating 57 / 1042 active labels
[1.41 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.41e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=2.41e-01, train/lr=5.00e-05, valid/glue=nan]




[1.46 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=1.30e-01] model:[train/loss=1.30e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[1.51 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.29e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=2.29e-01, train/lr=5.00e-05, valid/glue=nan]




[1.56 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.60e-01] model:[train/loss=2.60e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[1.61 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.26e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=3.26e-01, train/lr=5.00e-05, valid/glue=nan]




[1.67 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=5.38e-01] model:[train/loss=5.38e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[1.72 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=4.83e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=4.83e-01, train/lr=5.00e-05, valid/glue=nan]




[1.77 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.67e-01] model:[train/loss=3.67e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[1.82 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=8.09e-02, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=8.09e-02, train/lr=5.00e-05, valid/glue=nan]




[1.87 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.95e-01] model:[train/loss=3.95e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[1.92 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.83e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=2.83e-01, train/lr=5.00e-05, valid/glue=nan]




[1.97 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=2.65e-01] model:[train/loss=2.65e-01, train/lr=5.00e-05]
Evaluating 57 / 1042 active labels
[2.02 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=1.03e-01, COLA_valid/accuracy=7.02e-01, COLA_valid/matthews_corr=3.45e-01] model:[train/loss=1.03e-01, train/lr=5.00e-05, valid/glue=nan]




[2.07 epo]: COLA:[train/loss=None] COLA:ends_with_question_mark:[train/loss=3.86e-01] model:[train/loss=3.86e-01, train/lr=5.00e-05]

Finished training
Evaluating 615 / 8550 active labels
Evaluating 57 / 1042 active labels
Evaluating 53 / 1063 active labels
{'COLA:ends_with_question_mark/COLA_test/accuracy': 0.0,
 'COLA:ends_with_question_mark/COLA_test/matthews_corr': 0.0,
 'COLA:ends_with_question_mark/COLA_train/accuracy': 0.9024390243902439,
 'COLA:ends_with_question_mark/COLA_train/matthews_corr': 0.7717915771079772,
 'COLA:ends_with_question_mark/COLA_valid/accuracy': 0.7017543859649122,
 'COLA:ends_with_question_mark/COLA_valid/matthews_corr': 0.3445843938031584,
 'model/None/glue': nan}
Cleaning checkpoints
Writing metrics to /dfs/scratch0/vschen/metal-mmtl/logs/2019_03_17/01_59_39/metrics.json
Writing log to /dfs/scratch0/vschen/metal-mmtl/logs/2019_03_17/01_59_39/log.json


  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)


Full model saved at /dfs/scratch0/vschen/metal-mmtl/logs/2019_03_17/01_59_39/model.pkl


#### Did we improve?

In [17]:
from metal.mmtl.metal_model import MetalModel
model.score(payloads_slice[1])

Evaluating 57 / 1042 active labels


{'COLA:ends_with_question_mark/COLA_valid/accuracy': 0.7017543859649122,
 'COLA:ends_with_question_mark/COLA_valid/matthews_corr': 0.3445843938031584}