In [2]:
from functools import partial
from pathlib import Path

import torch as t
import wandb as wb
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from dataclasses import dataclass, asdict

import evaluate
from tqdm.auto import tqdm
from transformers import AutoModelForSequenceClassification

In [3]:
MODEL_NAME = "bert-base-uncased"

# Pick the best accelerator that is actually available instead of hard-coding
# "mps", which only works on Apple-silicon machines with MPS support.
# Preference order: CUDA, then MPS, then CPU.
if t.cuda.is_available():
    DEVICE = t.device("cuda")
elif t.backends.mps.is_available():
    DEVICE = t.device("mps")
else:
    DEVICE = t.device("cpu")

# Local data root. NOTE(review): not referenced by any cell shown here — confirm it is still needed.
DATAROOT = Path.home()/"mldata"

In [9]:
@dataclass()
class Hyperparams:
    """Training hyperparameters for the MRPC fine-tuning run.

    The training cell converts an instance with ``asdict`` and passes it as the
    W&B run config, so each field here is recorded with the experiment.
    """

    n_epochs: int  # number of full passes over the training split
    batch_size: int  # examples per batch (the training cell uses it for train and eval loaders)
    lr: float  # initial AdamW learning rate (the training cell's scheduler decays it)

In [4]:
def collate(tokenizer, xs):
    """Turn a list of MRPC examples into one padded sentence-pair batch.

    Args:
        tokenizer: tokenizer callable (e.g. a Hugging Face tokenizer). It is
            called with the first-sentence list and second-sentence list as a
            sentence-pair input.
        xs: examples, each a dict with at least ``"sentence1"``, ``"sentence2"``
            and ``"label"`` keys (other keys such as ``"idx"`` are ignored), e.g.
            ``{"sentence1": ..., "sentence2": ..., "label": 0, "idx": 1}``.

    Returns:
        Whatever ``tokenizer`` returns, with an added ``"labels"`` entry holding
        the example labels as a tensor.
    """
    rows = [(ex["sentence1"], ex["sentence2"], ex["label"]) for ex in xs]
    firsts = [row[0] for row in rows]
    seconds = [row[1] for row in rows]
    label_ids = [row[2] for row in rows]

    # padding=True pads to the longest pair in this batch; truncation=True clips overlong pairs.
    encoded = tokenizer(firsts, seconds, truncation=True, padding=True, return_tensors="pt")
    encoded["labels"] = t.tensor(label_ids)
    return encoded

In [5]:
def dataloaders(train_batch_size, eval_batch_size):
    """Build train/validation/test DataLoaders for GLUE MRPC.

    Only the training loader shuffles. All three share one collate function
    that tokenizes sentence pairs with the ``MODEL_NAME`` tokenizer.

    Args:
        train_batch_size: batch size for the training split.
        eval_batch_size: batch size for the validation and test splits.

    Returns:
        Tuple ``(traindl, valdl, testdl)``.
    """
    splits = load_dataset("glue", "mrpc")
    collate_fn = partial(collate, AutoTokenizer.from_pretrained(MODEL_NAME))

    def make_loader(split_name, batch_size, shuffle):
        # All loaders differ only in split, batch size and shuffling.
        return DataLoader(
            splits[split_name],
            shuffle=shuffle,
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

    return (
        make_loader("train", train_batch_size, True),
        make_loader("validation", eval_batch_size, False),
        make_loader("test", eval_batch_size, False),
    )

In [7]:
def eval(model, valdl, global_step):
    """Evaluate ``model`` on ``valdl`` and log validation metrics to W&B.

    NOTE: this name shadows Python's builtin ``eval`` in the notebook namespace;
    it is kept because the training cell calls it by this name.

    Args:
        model: classification model, called as ``model(**batch)``; must return
            an object with ``.loss`` and ``.logits`` (batches carry "labels").
        valdl: iterable of dict batches of tensors (e.g. from ``dataloaders``).
        global_step: step index passed to ``wb.log``.

    Returns:
        dict with the logged values: "val/accuracy", "val/F1", "val/loss".
        "val/loss" is the per-example mean loss (NaN if ``valdl`` is empty).
    """
    metric = evaluate.load("glue", "mrpc")
    loss_sum, n_examples = 0.0, 0
    model.eval()
    for batch in tqdm(valdl):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with t.no_grad():
            outputs = model(**batch)
        # Assumes outputs.loss is the mean over this batch (default for HF
        # sequence-classification heads). Weight by batch size so a smaller
        # final batch does not skew the average.
        batch_size = batch["labels"].size(0)
        loss_sum += outputs.loss.item() * batch_size
        n_examples += batch_size
        predictions = t.argmax(outputs.logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    eval_metrics = metric.compute()
    avg_loss = loss_sum / n_examples if n_examples else float("nan")
    logged = {
        "val/accuracy": eval_metrics["accuracy"],
        "val/F1": eval_metrics["f1"],
        "val/loss": avg_loss,
    }
    wb.log(logged, step=global_step)
    return logged

In [19]:
def train(model, traindl, global_step, optim, lr_scheduler):
    """Run one training epoch and log epoch-level metrics to W&B.

    ``lr_scheduler`` is stepped once per batch (not per epoch). This matches
    the training cell, which sizes the schedule as ``n_epochs * len(traindl)``.

    Args:
        model: classification model, called as ``model(**batch)``; must return
            an object with ``.loss`` and ``.logits``.
        traindl: iterable of dict batches of tensors (batches carry "labels").
        global_step: step counter at the start of the epoch; incremented once per batch.
        optim: optimizer over the model's parameters.
        lr_scheduler: scheduler stepped after every optimizer step.

    Returns:
        The updated ``global_step`` (input value + number of batches). The epoch
        metrics are logged at this step.
    """
    metric = evaluate.load("glue", "mrpc")
    loss_sum, n_examples = 0.0, 0
    model.train()
    # Make sure gradients are tracked even if called from inside a no_grad context.
    with t.enable_grad():
        for batch in tqdm(traindl):
            batch = {k: v.to(DEVICE) for k, v in batch.items()}

            optim.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optim.step()
            lr_scheduler.step()

            # Training metrics use this forward pass's logits, i.e. predictions
            # made before this batch's parameter update.
            preds = t.argmax(outputs.logits, dim=-1)
            metric.add_batch(predictions=preds, references=batch["labels"])
            # Assumes loss is the batch mean; weight by batch size so the
            # smaller final batch does not get outsized influence.
            batch_size = batch["labels"].size(0)
            loss_sum += loss.item() * batch_size
            n_examples += batch_size

            global_step += 1

    epoch_loss = loss_sum / n_examples if n_examples else float("nan")
    train_metrics = metric.compute()
    wb.log(
        {
            "loss/train": epoch_loss,
            "accuracy/train": train_metrics["accuracy"],
            "f1/train": train_metrics["f1"],
        },
        step=global_step,
    )
    return global_step



In [21]:
# Seed torch's RNG so loader shuffling, classifier-head init and dropout are repeatable.
SEED = 42
t.manual_seed(SEED)

hparams = Hyperparams(n_epochs=3, batch_size=8, lr=5e-5)

run = wb.init(
    project="finetune-mrpc",
    config={**asdict(hparams), "seed": SEED},
)

try:
    # testdl is not used in this cell.
    traindl, valdl, testdl = dataloaders(hparams.batch_size, hparams.batch_size)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
    model.to(DEVICE)
    optim = t.optim.AdamW(model.parameters(), lr=hparams.lr)
    # `train` steps the scheduler once per batch, so the schedule length is counted in batches.
    num_training_steps = hparams.n_epochs * len(traindl)
    # Linear decay from 1.0 * lr down to 0.3 * lr over the whole run.
    lr_scheduler = t.optim.lr_scheduler.LinearLR(
        optimizer=optim,
        start_factor=1.0,
        end_factor=0.3,
        total_iters=num_training_steps,
    )

    # Track the newly initialised classification head in W&B.
    wb.watch(model.classifier, log="all", log_freq=100)

    global_step = 1
    for epoch in range(hparams.n_epochs):
        global_step = train(model, traindl, global_step, optim, lr_scheduler)
        eval(model, valdl, global_step)
finally:
    # Close the W&B run even if training raises or is interrupted.
    wb.finish()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/459 [00:00<?, ?it/s]

  0%|          | 0/51 [00:00<?, ?it/s]

  0%|          | 0/459 [00:00<?, ?it/s]

python(15093) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15119) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15150) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15168) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15178) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15214) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15222) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15246) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15249) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15251) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15257) Malloc

  0%|          | 0/51 [00:00<?, ?it/s]

python(15746) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15764) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


  0%|          | 0/459 [00:00<?, ?it/s]

python(15775) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15820) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15834) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15839) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15842) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15858) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15861) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15864) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15870) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15871) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(15874) Malloc

  0%|          | 0/51 [00:00<?, ?it/s]

python(16242) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy/train,▁▅█
f1/train,▁▅█
loss/train,█▄▁
val/F1,█▁▇
val/accuracy,█▁▆
val/loss,▇▁█

0,1
accuracy/train,0.9452
f1/train,0.95934
loss/train,0.14761
val/F1,0.88235
val/accuracy,0.83333
val/loss,0.43604


In [20]:
# Manual cleanup cell: the training cell above already ends with wb.finish(),
# so on a clean top-to-bottom run this is redundant. It is only useful for
# closing a W&B run left open by an interrupted training cell.
wb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

