In [1]:
# Imports

In [2]:
    import torch
    from tqdm.auto import tqdm
    from dataloader import get_dataloader
    from transformers import BertForSequenceClassification
    from train import training_step
    from util import *

In [3]:
# Configs
### Supported GLUE tasks (valid values for `task`): "cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"

In [4]:
# Run configuration: every value a reader might want to tune lives in this cell.

# Task names accepted by this notebook (mirrors the list in the heading above).
GLUE_TASKS = {"cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"}

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint = "bert-base-uncased"  # checkpoint name handed to the dataloader and the model
task = "cola"                           # must be one of GLUE_TASKS
batch_size = 96                         # forwarded to get_dataloader
steps = 2000                            # total number of training steps run by the training loop
lr = 2e-5                               # learning rate handed to create_optimizer
seed = 42                               # RNG seed for reproducible runs

# Fail early on a typo: an unknown task would otherwise silently be treated
# as a two-label task when num_labels is derived further down.
if task not in GLUE_TASKS:
    raise ValueError(f"Unknown task {task!r}; expected one of {sorted(GLUE_TASKS)}")

# Seed PyTorch's RNGs so that random weight initialisation is repeatable.
# NOTE(review): if the dataloader shuffles with a different RNG (random / numpy),
# that RNG needs seeding too -- confirm in dataloader.py.
torch.manual_seed(seed)

In [5]:
# Load Dataloader and Pre-trained BERT Model

In [6]:
# Number of output labels for the selected task: 3 for any task whose name
# starts with "mnli", 1 for "stsb", and 2 for everything else.
if task.startswith("mnli"):
    num_labels = 3
elif task == "stsb":
    num_labels = 1
else:
    num_labels = 2

# Training-split dataloader for the chosen task and checkpoint.
train_epoch_iterator = get_dataloader(
    task,
    model_checkpoint,
    "train",
    batch_size=batch_size,
)

# Load the pre-trained checkpoint with a `num_labels`-way head, then move it
# onto the configured device.
model = BertForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
)
model = model.to(device)

Reusing dataset glue (C:\Users\ankur\.cache\huggingface\datasets\glue\cola\1.0.0\dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████| 9/9 [00:00<00:00, 22.90ba/s]
100%|██████████| 2/2 [00:00<00:00, 57.14ba/s]
100%|██████████| 2/2 [00:00<00:00, 58.82ba/s]
DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence', 'token_type_ids'],
        num_rows: 8551
    })
    validation: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence', 'token_type_ids'],
        num_rows: 1043
    })
    test: Dataset({
        features: ['attention_mask', 'idx', 'input_ids', 'label', 'sentence', 'token_type_ids'],
        num_rows: 1063
    })
})
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decod

In [7]:
# Optimizer and LR Scheduler

In [8]:
# Per-step history buffers; the training loop appends one entry per step.
tr_loss, tr_metric, tr_metric_1 = [], [], []

# Optimizer over the model at the configured learning rate, plus the
# scheduler built from that optimizer.
Optimizer = create_optimizer(model, learning_rate=lr)
LR_scheduler = create_scheduler(Optimizer)

# Primary and secondary metric for the task. The training loop treats
# Metric_1 as optional and checks it against None before using it.
Metric, Metric_1 = get_metrics(task)

Couldn't find file locally at matthews_correlation\matthews_correlation.py, or remotely at https://raw.githubusercontent.com/huggingface/datasets/1.6.2/metrics/matthews_correlation/matthews_correlation.py.
The file was picked from the master branch on github instead at https://raw.githubusercontent.com/huggingface/datasets/master/metrics/matthews_correlation/matthews_correlation.py.


In [9]:
# Training loop

In [10]:
# Fine-tune for exactly `steps` training steps, cycling through the dataloader
# as often as needed, and show running averages of loss/metrics on the bar.
steps_per_epoch = len(train_epoch_iterator)
if steps_per_epoch == 0:
    raise ValueError("train_epoch_iterator yields no batches; nothing to train on")

# Ceiling division: just enough passes over the data to reach `steps`.
# (`steps // steps_per_epoch + 1` schedules a surplus epoch whenever `steps`
# is an exact multiple of the epoch length.)
num_epochs = -(-steps // steps_per_epoch)
global_steps = 0

# from_pretrained() hands the model back in eval mode (dropout disabled), so
# switch to training mode explicitly. Harmless if training_step does it too.
model.train()

pbar = tqdm(total=steps, initial=global_steps)
for _ in range(num_epochs):
    # Iterating the dataloader directly replaces the Python-2 style
    # `iterator.next()`, which current DataLoader iterators do not provide.
    for batch in train_epoch_iterator:
        inputs = prepare_inputs(batch, device)
        step_loss, step_metric, step_metric_1 = training_step(
            model, inputs, Optimizer, LR_scheduler, Metric, Metric_1
        )
        global_steps += 1
        pbar.update()

        # Record this step. The metric results are dicts; only their first
        # value is tracked.
        tr_loss.append(step_loss)
        tr_metric.append(torch.tensor(list(step_metric.values())[0]))
        if Metric_1 is not None:
            tr_metric_1.append(torch.tensor(list(step_metric_1.values())[0]))

        # Running mean over (at most) the last epoch's worth of steps. Slice
        # the history *before* stacking so each step only touches the window
        # instead of re-stacking everything recorded so far.
        step_evaluation = {}
        step_evaluation['loss'] = torch.stack(tr_loss[-steps_per_epoch:]).mean().item()
        step_evaluation[Metric.__class__.__name__] = (
            torch.stack(tr_metric[-steps_per_epoch:]).mean().item()
        )
        if Metric_1 is not None:
            step_evaluation[Metric_1.__class__.__name__] = (
                torch.stack(tr_metric_1[-steps_per_epoch:]).mean().item()
            )
        pbar.set_postfix(step_evaluation)

        if global_steps >= steps:
            break
    # The inner `break` only leaves the batch loop; stop the epoch loop too.
    if global_steps >= steps:
        break
pbar.close()

100%|██████████| 200/200 [01:47<00:00,  1.51it/s, loss=0.605, MatthewsCorelation=-.000974]