In [45]:
from transformers import AutoTokenizer, DataCollatorWithPadding, AdamW, AutoModelForSequenceClassification, get_scheduler
from datasets import load_dataset
from accelerate import Accelerator
import evaluate
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

### Data Processing

In [46]:
# get dataset
raw_datasets = load_dataset("glue", "mrpc")
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 408
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 1725
    })
})

In [47]:
# check values of dataset
raw_datasets["validation"][86]

{'sentence1': 'He was arrested Friday night at an Alpharetta seafood restaurant while dining with his wife , singer Whitney Houston .',
 'sentence2': 'He was arrested again Friday night at an Alpharetta restaurant where he was having dinner with his wife .',
 'label': 1,
 'idx': 796}

In [48]:
# get model name and its tokenizer
checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [49]:
# define a tokenizer function for mapping to dataset (all values in dataset tokenized optimally through mapping)
def tokenizer_fn(dataset):
    return tokenizer(dataset["sentence1"], dataset["sentence2"], truncation=True)

tokenized_datasets = raw_datasets.map(tokenizer_fn, batched=True)
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 408
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1725
    })
})

In [50]:
# define collator (padding)
collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [51]:
# check lengths of samples
samples = tokenized_datasets["train"][:5]
samples = {k: v for k, v in samples.items() if k not in ["idx", "sentence1", "sentence2"]}
[len(x) for x in samples["input_ids"]] # different lengths

[50, 59, 47, 67, 59]

In [52]:
# test collator (padding should make each sample the same length)
batch = collator(samples)
{k: v.shape for k, v in batch.items()} # with padding now all same length

{'input_ids': torch.Size([5, 67]),
 'token_type_ids': torch.Size([5, 67]),
 'attention_mask': torch.Size([5, 67]),
 'labels': torch.Size([5])}

In [53]:
# remove unnecessary columns for training
tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", "sentence2", "idx"])

# rename for continuity
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

# set to torch tensor
tokenized_dataset = tokenized_datasets.set_format("torch")

tokenized_datasets["train"].column_names

['labels', 'input_ids', 'token_type_ids', 'attention_mask']

In [54]:
batch_size = 16

# make dataloaders from dataset
train_dl = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=batch_size, collate_fn=collator)
test_dl = DataLoader(tokenized_datasets["test"], shuffle=True, batch_size=batch_size, collate_fn=collator)

In [55]:
# check shapes for dataloaders (batch, sample length)
batch = next(iter(train_dl))
{k: v.shape for k, v in batch.items()}

{'labels': torch.Size([16]),
 'input_ids': torch.Size([16, 78]),
 'token_type_ids': torch.Size([16, 78]),
 'attention_mask': torch.Size([16, 78])}

### Load Model

In [56]:
# define accelerator that eases train-test process
accelerator = Accelerator()

In [57]:
# define model
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [58]:
# define optimizer
optimizer = AdamW(model.parameters(), lr=0.00001)



In [59]:
# set dataloaders, model and optimizer to accelerator (to gpu or smth)
train_dl, test_dl, model, optimizer = accelerator.prepare(train_dl, test_dl, model, optimizer)

In [60]:
# define epochs
epochs = 3
training_steps = epochs * len(train_dl)

In [61]:
# define scheduler for training
scheduler = get_scheduler(
    "linear",
    optimizer = optimizer,
    num_warmup_steps=0,
    num_training_steps=training_steps
)

### Train Model

In [62]:
# get metrics for dataset
metric = evaluate.load("glue", "mrpc")

In [63]:
prog = tqdm(range(epochs*training_steps))

for epoch in range(epochs):
    
    # TRAINING
    model.train()

    for batch in train_dl:

        # forward pass
        output = model(**batch)
        loss = output.loss
        
        # metrics
        logits = output.logits
        preds = torch.argmax(logits, dim=1)
        metric.add_batch(predictions=preds, references=batch["labels"])
        
        # backprop
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        prog.update(1)

    print("Training: ", metric.compute())

    
    # TESTING
    model.eval()

    with torch.inference_mode():
        for batch in test_dl:

            # forward pass
            output = model(**batch)
            
            # metrics
            logits = output.logits
            preds = torch.argmax(logits, dim=1)
            metric.add_batch(predictions=preds, references=batch["labels"])

    print("Test: ", metric.compute())

 11%|█         | 229/2070 [00:40<05:24,  5.68it/s]

Training:  {'accuracy': 0.697928026172301, 'f1': 0.8044475820684787}


 11%|█         | 231/2070 [00:45<43:33,  1.42s/it]

Test:  {'accuracy': 0.7547826086956522, 'f1': 0.8354725787631272}


 22%|██▏       | 460/2070 [01:25<04:03,  6.61it/s]

Training:  {'accuracy': 0.8034351145038168, 'f1': 0.860406582768635}


 22%|██▏       | 461/2070 [01:31<48:27,  1.81s/it]

Test:  {'accuracy': 0.8034782608695652, 'f1': 0.8606658446362515}


 33%|███▎      | 689/2070 [02:11<04:14,  5.43it/s]

Training:  {'accuracy': 0.888495092693566, 'f1': 0.9182163567286543}
Test:  {'accuracy': 0.8127536231884058, 'f1': 0.8645702306079665}
