<!--<badge>--><a href="https://colab.research.google.com/github/ankur-98/BERT_GLUE/blob/main/multi_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a><!--</badge>-->

# Setup for running on Colab

In [1]:
# Switch to GPU runtime
! git clone https://github.com/ankur-98/BERT_GLUE.git
import os 
os.chdir("BERT_GLUE")
! pip install datasets transformers

Cloning into 'BERT_GLUE'...
remote: Enumerating objects: 107, done.
remote: Counting objects: 100% (107/107), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 107 (delta 53), reused 46 (delta 21), pack-reused 0
Receiving objects: 100% (107/107), 32.42 KiB | 310.00 KiB/s, done.
Resolving deltas: 100% (53/53), done.
Collecting datasets
  Downloading https://files.pythonhosted.org/packages/46/1a/b9f9b3bfef624686ae81c070f0a6bb635047b17cdb3698c7ad01281e6f9a/datasets-1.6.2-py3-none-any.whl (221kB)
     |████████████████████████████████| 225kB 4.3MB/s 
Collecting transformers
  Downloading https://files.pythonhosted.org/packages/b0/9e/5b80becd952d5f7250eaf8fc64b957077b12ccfe73e9c03d37146ab29712/transformers-4.6.0-py3-none-any.whl (2.3MB)
     |████████████████████████████████| 2.3MB 31.4MB/s 
Collecting xxhash
  Downloading https://files.pythonhosted.org/packages/7d/4f/0a862cad26aa2ed7a7cd87178cbbfa824fc1383e472d63596a0d01837

# Imports

In [1]:
import torch
from tqdm.auto import tqdm
from dataloader import get_dataloader
from transformers import BertModel
from model import BERTClassifierModel
from train import training_step, eval_step
from util import *

# Configs
### Tasks: {"cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"}

In [6]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Name of the pre-trained checkpoint passed to the dataloaders and to BertModel.
model_checkpoint = "bert-base-uncased"

# Tasks handled by this notebook, in the order used for every per-task list below.
tasks = [
    "cola",
    "mnli",
    "mnli-mm",
    "mrpc",
    "qnli",
    "qqp",
    "rte",
    "sst2",
    "stsb",
    "wnli",
]

batch_size = 1
steps = 200
lr = 2e-5

# One of: "linear", "cosine", "cosine_with_restarts", "polynomial",
# "constant", "constant_with_warmup"
lr_scheduler_type = "linear"

# Load Dataloader and Pre-trained BERT Model

In [None]:
# Size of each task's output head: 3 for the MNLI variants, 1 for STS-B, 2 otherwise.
num_labels = []
for task_name in tasks:
    if task_name.startswith("mnli"):
        num_labels.append(3)
    elif task_name == "stsb":
        num_labels.append(1)
    else:
        num_labels.append(2)

# One train and one validation dataloader per task, index-aligned with `tasks`.
train_epoch_iterator = [
    get_dataloader(task_name, model_checkpoint, "train", batch_size=batch_size)
    for task_name in tasks
]
eval_epoch_iterator = [
    get_dataloader(task_name, model_checkpoint, "validation", batch_size=batch_size)
    for task_name in tasks
]

# A single pre-trained encoder instance is handed to every task model.
BERT_model = BertModel.from_pretrained(model_checkpoint)
models = []
for head_size, task_name in zip(num_labels, tasks):
    task_model = BERTClassifierModel(BERT_model, num_labels=head_size, task=task_name)
    models.append(task_model.to(device))

# Optimizer and LR Scheduler

In [None]:
# One optimizer per task model, then one LR scheduler per optimizer.
Optimizers = [create_optimizer(task_model, learning_rate=lr) for task_model in models]
LR_schedulers = [create_scheduler(task_optimizer, lr_scheduler_type) for task_optimizer in Optimizers]

# get_metrics returns a pair per task; split the pairs into two index-aligned tuples.
# Entries of Metrics_1 may be None (the loops below check for that).
Metrics, Metrics_1 = zip(*map(get_metrics, tasks))

# Running histories filled in by the training and evaluation cells.
tr_loss, eval_loss = [], []
tr_metrics = [list() for _ in tasks]
eval_metrics = [list() for _ in tasks]
tr_metrics_1 = [list() for _ in tasks]
eval_metrics_1 = [list() for _ in tasks]

# Training loop

In [None]:
# Multi-task training: each global step draws one batch from every task and
# runs one training step on that task's model. Stops after exactly `steps`
# global steps.
global_steps = 0
loss_pbar = tqdm(total=steps, initial=global_steps, desc=f"step summary:")
pbars = [tqdm(total=steps, initial=global_steps, desc=f"{task}") for task in tasks]

# Keep one live iterator per task. Calling iter() on the dataloader inside the
# step loop would start a brand-new pass on every step instead of advancing
# through the data.
train_iterators = [iter(dataloader) for dataloader in train_epoch_iterator]

while global_steps < steps:
    global_steps += 1
    step_loss = 0
    for i, task in enumerate(tasks):
        try:
            batch = next(train_iterators[i])
        except StopIteration:
            # This task ran out of batches before the step budget was reached:
            # begin another pass over its dataloader.
            train_iterators[i] = iter(train_epoch_iterator[i])
            batch = next(train_iterators[i])
        pbars[i].update()

        inputs = prepare_inputs(batch, device)
        step_task_loss, step_metric, step_metric_1 = training_step(models[i], inputs,
                                                              Optimizers[i], LR_schedulers[i],
                                                              Metrics[i], Metrics_1[i])

        # Defensive detach: the stored loss history must not keep an autograd
        # graph alive (no-op if training_step already returns a detached loss).
        step_loss += step_task_loss.detach()
        tr_metrics[i].append(torch.tensor(list(step_metric.values())[0]))
        if Metrics_1[i] is not None:
            tr_metrics_1[i].append(torch.tensor(list(step_metric_1.values())[0]))

        # Running mean over (at most) the last pass of this task's dataloader.
        # Slice the Python list before stacking so the cost per step is bounded
        # by the window rather than by the whole history.
        window = len(train_epoch_iterator[i])
        step_evaluation = {}
        step_evaluation[f"{Metrics[i].__class__.__name__}"] = torch.stack(tr_metrics[i][-window:]).mean().item()
        if Metrics_1[i] is not None:
            step_evaluation[f"{Metrics_1[i].__class__.__name__}"] = torch.stack(tr_metrics_1[i][-window:]).mean().item()
        pbars[i].set_postfix(step_evaluation)

    loss_pbar.update()
    # step_loss is the sum of the per-task losses for this global step.
    tr_loss.append(step_loss)
    step_evaluation = {}
    step_evaluation['loss'] = torch.stack(tr_loss[-len(train_epoch_iterator[0]):]).mean().item()
    loss_pbar.set_postfix(step_evaluation)

# Evaluation

In [None]:
# Evaluation: run every task's model over its full validation dataloader,
# showing running metric means per task and the mean loss across tasks.
print(f"Evaluation begins in batches of {batch_size}..")
trange = [len(iterator) for iterator in eval_epoch_iterator]
# The summary bar advances once per task; tqdm needs the size as `total=`,
# an int passed positionally would be taken as the iterable.
loss_pbar = tqdm(total=len(tasks), desc=f"step summary:")
pbars = [tqdm(total=trange[i], desc=f"{task}") for i, task in enumerate(tasks)]

# Start from empty histories so re-running this cell does not mix in results
# from a previous evaluation. Cleared in place to keep the list objects intact.
eval_loss.clear()
for i, task in enumerate(tasks):
    eval_metrics[i].clear()
    eval_metrics_1[i].clear()
    task_losses = []

    # Iterate the dataloader itself so each batch is visited exactly once.
    # Rebuilding the iterator per step would evaluate the first batch repeatedly.
    for batch in eval_epoch_iterator[i]:
        pbars[i].update()

        inputs = prepare_inputs(batch, device)
        step_task_loss, step_metric, step_metric_1 = eval_step(models[i], inputs,
                                                               Metrics[i], Metrics_1[i])

        # Defensive detach so the per-batch losses never hold an autograd graph.
        task_losses.append(step_task_loss.detach())
        eval_metrics[i].append(torch.tensor(list(step_metric.values())[0]))
        if Metrics_1[i] is not None:
            eval_metrics_1[i].append(torch.tensor(list(step_metric_1.values())[0]))

        step_evaluation = {}
        step_evaluation[f"{Metrics[i].__class__.__name__}"] = torch.stack(eval_metrics[i]).mean().item()
        if Metrics_1[i] is not None:
            step_evaluation[f"{Metrics_1[i].__class__.__name__}"] = torch.stack(eval_metrics_1[i]).mean().item()
        pbars[i].set_postfix(step_evaluation)

    loss_pbar.update()
    # Record this task's mean loss over all of its batches (skipped when the
    # dataloader yielded nothing, since torch.stack rejects an empty list).
    if task_losses:
        eval_loss.append(torch.stack(task_losses).mean())
    if eval_loss:
        step_evaluation = {}
        step_evaluation['loss'] = torch.stack(eval_loss).mean().item()
        loss_pbar.set_postfix(step_evaluation)