In [1]:
# Load IPython's autoreload extension so edited project modules can be
# re-imported without restarting the kernel.
%load_ext autoreload
# Mode 2: reload modules automatically before each cell is executed.
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from metal.end_model import EndModel
from metal.mmtl.utils.dataset_utils import get_all_dataloaders
from metal.mmtl.dataset import QNLIDataset
from metal.mmtl.modules import BertEncoder
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.scorer import Scorer
from metal.mmtl.task import Task
from metal.mmtl.trainer import MultitaskTrainer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


### Config

In [8]:
# --- Model / data settings -------------------------------------------------
bert_model = 'bert-base-uncased'   # name passed to QNLIDataset and BertEncoder
bert_model_output_shape = 768      # used as in_features of the ranking head
max_len = 512                      # passed to QNLIDataset as max_len
batch_size = 16                    # batch size for every DataLoader built below
split_prop = 0.8                   # only referenced by the disabled get_all_dataloaders call

# Root directory for checkpoints. Fall back to the current working directory
# when METALHOME is unset or empty, instead of failing with a KeyError (unset)
# or silently resolving to "/checkpoints/..." at the filesystem root (empty).
metal_home = os.environ.get("METALHOME") or os.getcwd()

# --- Keyword arguments forwarded to MultitaskTrainer.train_model ------------
trainer_config = {
    "verbose": True,
    "device": "cuda",
    "loss_fn_reduction": "mean",
    "progress_bar": True,
    #"data_loader_config": {"batch_size": 32, "num_workers": 1, "shuffle": True}, ## TODO?
    "n_epochs": 1,
    # 'grad_clip': 1.0,  ## TODO?
    "l2": 0.01,
    "optimizer_config": {
        "optimizer": "adam",
        "optimizer_common": {"lr": 1e-5},
        "adam_config": {"betas": (0.9, 0.999)},
    },
    "lr_scheduler": "exponential",  # reduce_on_plateau  ## TODO? Warmup
    "lr_scheduler_config": {
        "lr_freeze": 0,
        # Scheduler - exponential
        "exponential_config": {"gamma": 0.9},  # decay rate
        # Scheduler - reduce_on_plateau
        # NOTE(review): min_lr (1e-4) is larger than the optimizer lr (1e-5).
        # Harmless while "lr_scheduler" is "exponential", but looks unintended
        # if the scheduler is ever switched to reduce_on_plateau -- confirm.
        "plateau_config": {
            "factor": 0.5,
            "patience": 10,
            "threshold": 0.0001,
            "min_lr": 1e-4,
        },
    },
    # Logger (see metal/logging/logger.py for descriptions)
    "logger": True,
    "logger_config": {
        "log_unit": "epochs",  # ['seconds', 'examples', 'batches', 'epochs']
        "log_every": 0.05,
        "score_every": 0.1,
    },
    # Checkpointer (see metal/logging/checkpointer.py for descriptions)
    "checkpoint": True,  # If True, checkpoint models when certain conditions are met
    "checkpoint_config": {
        "checkpoint_every": 0,  # Save a model checkpoint every this many log_units
        "checkpoint_best": True,
        # "checkpoint_final": False,  # Save a model checkpoint at the end of training
        "checkpoint_metric": "ranking/valid/accuracy",
        "checkpoint_metric_mode": "max",
        "checkpoint_dir": f"{metal_home}/checkpoints/qnli_single",
        "checkpoint_runway": 0,
    },
}

In [4]:
# Alternative (currently disabled): build every split through the helper, i.e.
#   get_all_dataloaders("QNLI", bert_model,
#                       train_dev_split_prop=split_prop, max_len=max_len)
# Instead, only the 'test' and 'dev' splits are loaded by hand below.
dataloaders = {}
for split in ('test', 'dev'):
    dataset = QNLIDataset(split=split, bert_model=bert_model, max_len=max_len)
    dataloaders[split] = dataset.get_dataloader(shuffle=True, batch_size=batch_size)

# CAUTION: 'train' and 'valid' both point at the very same 'dev' DataLoader
# object, so validation metrics are computed on the data the model trains on.
dataloaders.update(train=dataloaders['dev'], valid=dataloaders['dev'])

100%|██████████| 5463/5463 [00:04<00:00, 1224.48it/s]
100%|██████████| 5463/5463 [00:04<00:00, 1234.72it/s]


In [5]:
# Build the encoder for the configured `bert_model`; it is passed to the
# ranking Task as its input_module.
bert_encoder = BertEncoder(bert_model)

In [6]:
from functools import partial
from typing import Callable, List
import torch.nn.functional as F


def ranking_loss(X, Y):
    """Cross-entropy between the head outputs X and the labels Y shifted by -1.

    The shift presumably maps 1-indexed labels onto the 0-based class indices
    that F.cross_entropy expects -- confirm against QNLIDataset.
    """
    return F.cross_entropy(X, Y - 1)


# Two-way linear head (no bias term) on top of the encoder output.
ranking_head = nn.Linear(bert_model_output_shape, 2, bias=False)

ranking_task = Task(
    name="ranking",
    data_loaders=dataloaders,
    input_module=bert_encoder,
    head_module=ranking_head,
    scorer=Scorer(standard_metrics=["accuracy"]),
    loss_hat_func=ranking_loss,
    # Softmax over dim 1 of the head output.
    output_hat_func=partial(F.softmax, dim=1),
)

In [9]:
# Single-task setup: the ranking task is the only entry in the task list.
tasks = [ranking_task]
model = MetalModel(tasks, verbose=False)

# Train `model` on `tasks`, forwarding every entry of `trainer_config`
# as a keyword argument.
trainer = MultitaskTrainer()
trainer.train_model(model, tasks, **trainer_config)

Using GPU...


[0.050521691378363535 epo]: TRAIN:[loss=0.641]
{'ranking/loss': 0.6414394093596417, 'train/loss': 0.6414394093596417}
ranking/valid/accuracy
[0.10104338275672707 epo]: TRAIN:[loss=0.523] VALID:[ranking/accuracy=0.813]
{'ranking/loss': 0.5226242652405864, 'train/loss': 0.5226242652405864, 'ranking/valid/accuracy': 0.8125572030020135}
ranking/valid/accuracy
hello
Saving model at iteration 0.10104338275672707 with best (max) score 0.813
[0.1515650741350906 epo]: TRAIN:[loss=0.466]
{'ranking/loss': 0.46633189868019975, 'train/loss': 0.46633189868019975}
ranking/valid/accuracy

Restoring best model from iteration 0.10104338275672707 with score 0.813
Finished Training
{'ranking/valid/accuracy': 0.8125572030020135}


In [17]:
# Sanity check on a single batch of the 'dev' loader: print the model's
# forward result, its loss and its calculated output for the 'ranking' task.
# Tensors are moved to the device named in the trainer config rather than a
# hard-coded .cuda(), so the cell keeps working if "device" is changed.
eval_device = trainer_config["device"]

# Inspection only: nothing is back-propagated, so skip building the autograd graph.
with torch.no_grad():
    for (X, Y) in dataloaders['dev']:
        X = [x.to(eval_device) for x in X]
        print(model(X, ['ranking']))
        print(model.calculate_loss(X, Y.to(eval_device), ['ranking']))
        print(model.calculate_output(X, ['ranking']))
        break  # one batch is enough for the check

{'ranking': tensor([[ 0.3899, -0.5400],
        [ 0.4240, -0.5404],
        [ 0.3580, -0.6234],
        [ 0.4644, -0.6587],
        [-0.8041,  1.6927],
        [-0.4173,  0.9388],
        [-0.3887,  0.4791],
        [ 0.4736, -0.7310],
        [-0.5424,  0.5344],
        [-0.2840,  1.0988],
        [-0.3985,  0.9711],
        [-0.5806,  1.2527],
        [ 0.7220, -0.9857],
        [ 0.2798, -0.3332],
        [ 0.3106, -0.4223],
        [-0.4811,  1.6733]], device='cuda:0', grad_fn=<MmBackward>)}
{'ranking': tensor(0.3953, device='cuda:0', grad_fn=<NllLossBackward>)}
{'ranking': tensor([[0.7171, 0.2829],
        [0.7240, 0.2760],
        [0.7274, 0.2726],
        [0.7546, 0.2454],
        [0.0761, 0.9239],
        [0.2049, 0.7951],
        [0.2957, 0.7043],
        [0.7693, 0.2307],
        [0.2541, 0.7459],
        [0.2006, 0.7994],
        [0.2027, 0.7973],
        [0.1378, 0.8622],
        [0.8465, 0.1535],
        [0.6486, 0.3514],
        [0.6754, 0.3246],
        [0.1039, 0.8961]]