# Practical classification with pre-trained BERT

In this notebook I download pre-trained BERT model and fine-tune it with high-level HuggingFace tools.

There is another notebook, doing the same with lower-level PyTorch tools only.

## References:
* https://huggingface.co/course/chapter3/4?fw=pt - HuggingFace transformers course reference

In [None]:
# minimal example of using a pre-trained model for classification
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
sequences = ["I've been waiting for a HuggingFace course my whole life.", "So have I!"]

tokens = tokenizer(sequences, padding=True, truncation=True, return_tensors="pt")
output = model(**tokens)

torch.nn.functional.softmax(output.logits, dim=1)

In [1]:
import pandas as pd


essays = pd.read_csv("./data/essays.csv")

essays.loc[essays['cEXT'] == 'n', 'cEXT'] = 0
essays.loc[essays['cEXT'] == 'y', 'cEXT'] = 1

essays.loc[essays['cNEU'] == 'n', 'cNEU'] = 0
essays.loc[essays['cNEU'] == 'y', 'cNEU'] = 1

essays.loc[essays['cAGR'] == 'n', 'cAGR'] = 0
essays.loc[essays['cAGR'] == 'y', 'cAGR'] = 1

essays.loc[essays['cCON'] == 'n', 'cCON'] = 0
essays.loc[essays['cCON'] == 'y', 'cCON'] = 1

essays.loc[essays['cOPN'] == 'n', 'cOPN'] = 0
essays.loc[essays['cOPN'] == 'y', 'cOPN'] = 1

essays.astype({'cEXT': 'int32', 'cNEU': 'int32', 'cAGR': 'int32', 'cCON': 'int32', 'cOPN': 'int32'}).dtypes

essays

Unnamed: 0,#AUTHID,TEXT,cEXT,cNEU,cAGR,cCON,cOPN
0,1997_504851.txt,"Well, right now I just woke up from a mid-day ...",0,1,1,0,1
1,1997_605191.txt,"Well, here we go with the stream of consciousn...",0,0,1,0,0
2,1997_687252.txt,An open keyboard and buttons to push. The thin...,0,1,0,1,1
3,1997_568848.txt,I can't believe it! It's really happening! M...,1,0,1,1,0
4,1997_688160.txt,"Well, here I go with the good old stream of co...",1,0,1,0,1
...,...,...,...,...,...,...,...
2462,2004_493.txt,I'm home. wanted to go to bed but remembe...,0,1,0,1,0
2463,2004_494.txt,Stream of consiousnesssskdj. How do you s...,1,1,0,0,1
2464,2004_497.txt,"It is Wednesday, December 8th and a lot has be...",0,0,1,0,0
2465,2004_498.txt,"Man this week has been hellish. Anyways, now i...",0,1,0,0,1


In [2]:
import torch
from torch.utils.data import DataLoader, random_split, default_convert
from transformers import AdamW, AutoTokenizer, BertForSequenceClassification
from datasets import Dataset, DatasetDict


# prepare dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(essays):
    return tokenizer(essays["TEXT"], padding="max_length", truncation=True)  # , return_tensors="pt")

essays_dataset = Dataset.from_pandas(essays)
tokenized_dataset = essays_dataset.map(tokenize_function, batched=True, batch_size=8)
tokenized_dataset = tokenized_dataset.rename_column("TEXT", "text")
tokenized_dataset = tokenized_dataset.rename_column("cNEU", "labels")
tokenized_dataset = tokenized_dataset.remove_columns(['#AUTHID', 'text', 'cEXT', 'cAGR', 'cCON', 'cOPN'])

train_dataset, validation_dataset = random_split(tokenized_dataset, [2000, len(tokenized_dataset) - 2000])

ds = DatasetDict()
ds['train'] = train_dataset
ds['validation'] = validation_dataset

# vocab = tokenizer.get_vocab()
# ivocab = {v: k for k, v in vocab.items()}
print(ds['train'][0]['input_ids'])

train_dataloader = DataLoader(ds['train'], shuffle=True, batch_size=8)

  from .autonotebook import tqdm as notebook_tqdm
100%|█████████████████████████████████████████████████| 309/309 [00:01<00:00, 236.43ba/s]

[101, 2157, 2085, 1045, 2572, 1999, 2026, 19568, 2282, 13400, 2013, 4280, 1998, 8225, 2005, 2019, 3944, 3561, 3262, 2007, 5702, 1012, 2026, 5551, 2024, 4909, 1996, 6919, 4165, 1997, 2990, 3779, 1010, 2990, 7361, 3771, 3401, 1004, 4913, 12351, 1010, 1998, 2036, 8292, 18622, 4542, 1006, 1037, 3409, 5297, 4438, 1007, 1012, 1045, 2064, 2025, 2360, 2008, 1045, 2572, 19773, 2505, 3391, 2138, 2026, 8254, 25581, 2024, 3772, 2039, 2044, 1037, 2208, 1997, 4715, 2197, 2305, 1012, 1045, 3984, 2045, 2003, 2074, 2242, 2055, 1996, 5568, 2091, 2011, 1996, 6053, 3501, 23497, 2008, 8769, 2039, 2026, 2010, 15464, 10586, 1012, 1045, 2572, 3110, 1037, 2210, 11480, 1998, 13233, 2157, 2085, 1012, 1045, 8271, 2039, 2012, 2184, 1024, 4002, 2023, 2851, 2074, 2004, 2026, 2034, 2465, 2001, 3225, 1012, 2023, 2465, 2003, 2235, 1998, 2200, 2172, 5270, 2241, 1010, 2061, 1045, 2071, 2025, 3432, 13558, 2009, 1012, 1045, 2018, 2000, 5481, 2041, 1996, 2341, 1998, 1045, 2903, 2008, 2023, 5481, 2038, 3303, 2033, 2000, 2514




In [3]:
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2, output_attentions=True)

# I'm running this on Apple Silicon. Activate Metal "mps" device, if available:
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

else:
    mps_device = torch.device("mps")


torch.device("mps")
model.to(mps_device)

model.train()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [11]:
from transformers import get_scheduler
from tqdm.auto import tqdm
from torch.optim import AdamW


# parameters
num_epochs = 1  # 3
num_training_steps = num_epochs * len(train_dataloader)

cross_entropy_loss = torch.nn.CrossEntropyLoss().to(mps_device)

optimizer = AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)


# test on one batch
# batch = next(iter(train_dataloader))

# labels = batch["labels"]
# del batch["labels"]

# batch = {k: torch.transpose(torch.stack(default_convert(v)), 0, 1) for k, v in batch.items()}
# batch = {k: v.to(mps_device) for k, v in batch.items()}

# output = model(**batch)
# labels.to(mps_device)
# mps_labels = torch.as_tensor(labels, device=mps_device)

# loss = cross_entropy_loss(output.logits, mps_labels)
# loss.backward()


# progress bar
progress_bar = tqdm(range(num_training_steps))

# training
for epoch in range(num_epochs):
    for batch in train_dataloader:
        labels = batch["labels"].to(mps_device)
        mps_labels = torch.as_tensor(labels, device=mps_device)
        del batch["labels"]
        
        batch = {k: torch.transpose(torch.stack(default_convert(v)), 0, 1) for k, v in batch.items()}
        batch = {k: v.to(mps_device) for k, v in batch.items()}

        outputs = model(**batch, labels=mps_labels)
        loss, logits = outputs[:2]

        # loss = cross_entropy_loss(output.logits, mps_labels)
        print("Training loss: {0:.2f}".format(loss))
        loss.backward()        

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)


100%|██████████████████████████████████████████████████| 250/250 [48:20<00:00, 11.60s/it][A


Training loss: 0.69



  0%|▏                                                   | 1/250 [00:06<25:40,  6.19s/it][A

Training loss: 0.72



  1%|▍                                                   | 2/250 [00:11<24:18,  5.88s/it][A

Training loss: 0.67



  1%|▌                                                   | 3/250 [00:17<24:30,  5.95s/it][A

Training loss: 0.68



  2%|▊                                                   | 4/250 [00:23<24:03,  5.87s/it][A

Training loss: 0.73



  2%|█                                                   | 5/250 [00:29<23:32,  5.76s/it][A

Training loss: 0.66



  2%|█▏                                                  | 6/250 [00:34<23:19,  5.74s/it][A

Training loss: 0.71



  3%|█▍                                                  | 7/250 [00:40<23:06,  5.71s/it][A

Training loss: 0.70



  3%|█▋                                                  | 8/250 [00:46<22:53,  5.67s/it][A

Training loss: 0.68



  4%|█▊                                                  | 9/250 [00:51<22:45,  5.67s/it][A

Training loss: 0.71



  4%|██                                                 | 10/250 [00:57<23:00,  5.75s/it][A

Training loss: 0.67



  4%|██▏                                                | 11/250 [01:03<22:37,  5.68s/it][A

Training loss: 0.73



  5%|██▍                                                | 12/250 [01:08<22:27,  5.66s/it][A

Training loss: 0.65



  5%|██▋                                                | 13/250 [01:14<22:03,  5.58s/it][A

Training loss: 0.72



  6%|██▊                                                | 14/250 [01:20<22:16,  5.66s/it][A

Training loss: 0.74



  6%|███                                                | 15/250 [01:25<22:25,  5.72s/it][A

Training loss: 0.72



  6%|███▎                                               | 16/250 [01:31<22:18,  5.72s/it][A

Training loss: 0.74



  7%|███▍                                               | 17/250 [01:37<22:30,  5.80s/it][A

Training loss: 0.69



  7%|███▋                                               | 18/250 [01:43<22:49,  5.90s/it][A

Training loss: 0.69



  8%|███▉                                               | 19/250 [01:49<22:45,  5.91s/it][A

Training loss: 0.69



  8%|████                                               | 20/250 [01:55<22:30,  5.87s/it][A

Training loss: 0.71



  8%|████▎                                              | 21/250 [02:01<22:24,  5.87s/it][A

Training loss: 0.69



  9%|████▍                                              | 22/250 [02:07<22:19,  5.88s/it][A

Training loss: 0.70



  9%|████▋                                              | 23/250 [02:12<22:01,  5.82s/it][A

Training loss: 0.72



 10%|████▉                                              | 24/250 [02:18<21:43,  5.77s/it][A

Training loss: 0.70



 10%|█████                                              | 25/250 [02:24<21:45,  5.80s/it][A

Training loss: 0.67



 10%|█████▎                                             | 26/250 [02:30<21:54,  5.87s/it][A

Training loss: 0.70



 11%|█████▌                                             | 27/250 [02:36<21:36,  5.81s/it][A

Training loss: 0.69



 11%|█████▋                                             | 28/250 [02:41<21:12,  5.73s/it][A

Training loss: 0.68



 12%|█████▉                                             | 29/250 [02:47<21:05,  5.73s/it][A

Training loss: 0.69



 12%|██████                                             | 30/250 [02:53<20:48,  5.67s/it][A

Training loss: 0.74



 12%|██████▎                                            | 31/250 [02:58<20:40,  5.66s/it][A

Training loss: 0.68



 13%|██████▌                                            | 32/250 [03:04<20:37,  5.68s/it][A

Training loss: 0.74



 13%|██████▋                                            | 33/250 [03:10<20:30,  5.67s/it][A

Training loss: 0.72



 14%|██████▉                                            | 34/250 [03:15<20:30,  5.70s/it][A

Training loss: 0.67



 14%|███████▏                                           | 35/250 [03:21<20:14,  5.65s/it][A

Training loss: 0.73



 14%|███████▎                                           | 36/250 [03:26<20:07,  5.64s/it][A

Training loss: 0.68



 15%|███████▌                                           | 37/250 [03:32<20:07,  5.67s/it][A

Training loss: 0.70



 15%|███████▊                                           | 38/250 [03:38<19:41,  5.58s/it][A

Training loss: 0.70



 16%|███████▉                                           | 39/250 [03:43<19:47,  5.63s/it][A

Training loss: 0.70



 16%|████████▏                                          | 40/250 [03:49<19:39,  5.62s/it][A

Training loss: 0.69



 16%|████████▎                                          | 41/250 [03:55<19:47,  5.68s/it][A

Training loss: 0.70



 17%|████████▌                                          | 42/250 [04:00<19:33,  5.64s/it][A

Training loss: 0.68



 17%|████████▊                                          | 43/250 [04:06<19:19,  5.60s/it][A

Training loss: 0.73



 18%|████████▉                                          | 44/250 [04:12<19:24,  5.65s/it][A

Training loss: 0.69



 18%|█████████▏                                         | 45/250 [04:17<19:32,  5.72s/it][A

Training loss: 0.70



 18%|█████████▍                                         | 46/250 [04:24<19:51,  5.84s/it][A

Training loss: 0.68



 19%|█████████▌                                         | 47/250 [04:29<19:43,  5.83s/it][A

Training loss: 0.70



 19%|█████████▊                                         | 48/250 [04:35<19:45,  5.87s/it][A

Training loss: 0.71



 20%|█████████▉                                         | 49/250 [04:41<19:33,  5.84s/it][A

Training loss: 0.70



 20%|██████████▏                                        | 50/250 [04:47<19:25,  5.83s/it][A

Training loss: 0.72



 20%|██████████▍                                        | 51/250 [04:52<18:58,  5.72s/it][A

Training loss: 0.67



 21%|██████████▌                                        | 52/250 [04:58<18:54,  5.73s/it][A

Training loss: 0.66



 21%|██████████▊                                        | 53/250 [05:04<19:10,  5.84s/it][A

Training loss: 0.71



 22%|███████████                                        | 54/250 [05:10<19:26,  5.95s/it][A

Training loss: 0.64



 22%|███████████▏                                       | 55/250 [05:17<19:41,  6.06s/it][A

Training loss: 0.74



 22%|███████████▍                                       | 56/250 [05:23<19:38,  6.08s/it][A

Training loss: 0.71



 23%|███████████▋                                       | 57/250 [05:29<19:28,  6.06s/it][A

Training loss: 0.64



 23%|███████████▊                                       | 58/250 [05:35<19:19,  6.04s/it][A

Training loss: 0.67



 24%|████████████                                       | 59/250 [05:41<18:56,  5.95s/it][A

Training loss: 0.69



 24%|████████████▏                                      | 60/250 [05:47<18:50,  5.95s/it][A

Training loss: 0.79



 24%|████████████▍                                      | 61/250 [05:52<18:35,  5.90s/it][A

Training loss: 0.69



 25%|████████████▋                                      | 62/250 [05:58<18:22,  5.87s/it][A

Training loss: 0.65



 25%|████████████▊                                      | 63/250 [06:04<18:30,  5.94s/it][A

Training loss: 0.72



 26%|█████████████                                      | 64/250 [06:10<18:22,  5.93s/it][A

Training loss: 0.69



 26%|█████████████▎                                     | 65/250 [06:16<18:15,  5.92s/it][A

Training loss: 0.67



 26%|█████████████▍                                     | 66/250 [06:22<17:59,  5.87s/it][A

Training loss: 0.68



 27%|█████████████▋                                     | 67/250 [06:27<17:32,  5.75s/it][A

Training loss: 0.72



 27%|█████████████▊                                     | 68/250 [06:33<17:17,  5.70s/it][A

Training loss: 0.70



 28%|██████████████                                     | 69/250 [06:39<17:28,  5.80s/it][A

Training loss: 0.67



 28%|██████████████▎                                    | 70/250 [06:44<17:10,  5.73s/it][A

Training loss: 0.66



 28%|██████████████▍                                    | 71/250 [06:50<16:53,  5.66s/it][A

Training loss: 0.69



 29%|██████████████▋                                    | 72/250 [06:56<16:53,  5.69s/it][A

Training loss: 0.71



 29%|██████████████▉                                    | 73/250 [07:01<16:46,  5.69s/it][A

Training loss: 0.70



 30%|███████████████                                    | 74/250 [07:07<16:47,  5.72s/it][A

Training loss: 0.72



 30%|███████████████▎                                   | 75/250 [07:13<16:27,  5.64s/it][A

Training loss: 0.65



 30%|███████████████▌                                   | 76/250 [07:18<16:12,  5.59s/it][A

Training loss: 0.70



 31%|███████████████▋                                   | 77/250 [07:23<15:54,  5.52s/it][A

Training loss: 0.70



 31%|███████████████▉                                   | 78/250 [07:29<16:00,  5.58s/it][A

Training loss: 0.68



 32%|████████████████                                   | 79/250 [07:35<16:04,  5.64s/it][A

Training loss: 0.69



 32%|████████████████▎                                  | 80/250 [07:41<16:16,  5.74s/it][A

Training loss: 0.68



 32%|████████████████▌                                  | 81/250 [07:47<16:16,  5.78s/it][A

Training loss: 0.68



 33%|████████████████▋                                  | 82/250 [07:52<15:54,  5.68s/it][A

Training loss: 0.68



 33%|████████████████▉                                  | 83/250 [07:58<15:50,  5.69s/it][A

Training loss: 0.72



 34%|█████████████████▏                                 | 84/250 [08:04<15:55,  5.75s/it][A

Training loss: 0.69



 34%|█████████████████▎                                 | 85/250 [08:10<15:50,  5.76s/it][A

Training loss: 0.68



 34%|█████████████████▌                                 | 86/250 [08:16<16:24,  6.00s/it][A

Training loss: 0.69



 35%|█████████████████▋                                 | 87/250 [08:22<16:14,  5.98s/it][A

Training loss: 0.71



 35%|█████████████████▉                                 | 88/250 [08:28<16:10,  5.99s/it][A

Training loss: 0.74



 36%|██████████████████▏                                | 89/250 [08:34<15:56,  5.94s/it][A

Training loss: 0.72



 36%|██████████████████▎                                | 90/250 [08:40<15:41,  5.89s/it][A

Training loss: 0.71



 36%|██████████████████▌                                | 91/250 [08:45<15:29,  5.85s/it][A

Training loss: 0.69



 37%|██████████████████▊                                | 92/250 [08:51<15:17,  5.80s/it][A

Training loss: 0.68



 37%|██████████████████▉                                | 93/250 [08:57<14:51,  5.68s/it][A

Training loss: 0.70



 38%|███████████████████▏                               | 94/250 [09:02<14:41,  5.65s/it][A

Training loss: 0.68



 38%|███████████████████▍                               | 95/250 [09:08<14:26,  5.59s/it][A

Training loss: 0.65



 38%|███████████████████▌                               | 96/250 [09:13<14:20,  5.58s/it][A

Training loss: 0.68



 39%|███████████████████▊                               | 97/250 [09:19<14:10,  5.56s/it][A

Training loss: 0.69



 39%|███████████████████▉                               | 98/250 [09:25<14:21,  5.67s/it][A

Training loss: 0.63



 40%|████████████████████▏                              | 99/250 [09:30<14:18,  5.69s/it][A

Training loss: 0.67



 40%|████████████████████                              | 100/250 [09:36<14:17,  5.72s/it][A

Training loss: 0.79



 40%|████████████████████▏                             | 101/250 [09:42<14:29,  5.83s/it][A

Training loss: 0.72



 41%|████████████████████▍                             | 102/250 [09:48<14:20,  5.82s/it][A

Training loss: 0.68



 41%|████████████████████▌                             | 103/250 [09:54<14:02,  5.73s/it][A

Training loss: 0.76



 42%|████████████████████▊                             | 104/250 [09:59<13:53,  5.71s/it][A

Training loss: 0.69



 42%|█████████████████████                             | 105/250 [10:05<13:31,  5.59s/it][A

Training loss: 0.75



 42%|█████████████████████▏                            | 106/250 [10:10<13:27,  5.61s/it][A

Training loss: 0.77



 43%|█████████████████████▍                            | 107/250 [10:16<13:19,  5.59s/it][A

Training loss: 0.80



 43%|█████████████████████▌                            | 108/250 [10:21<13:14,  5.59s/it][A

Training loss: 0.68



 44%|█████████████████████▊                            | 109/250 [10:27<13:19,  5.67s/it][A

Training loss: 0.75



 44%|██████████████████████                            | 110/250 [10:33<13:28,  5.78s/it][A

Training loss: 0.69



 44%|██████████████████████▏                           | 111/250 [10:40<13:48,  5.96s/it][A

Training loss: 0.71



 45%|██████████████████████▍                           | 112/250 [10:45<13:36,  5.92s/it][A

Training loss: 0.71



 45%|██████████████████████▌                           | 113/250 [10:51<13:26,  5.89s/it][A

Training loss: 0.67



 46%|██████████████████████▊                           | 114/250 [10:57<13:14,  5.84s/it][A

Training loss: 0.69



 46%|███████████████████████                           | 115/250 [11:03<13:21,  5.94s/it][A

Training loss: 0.72



 46%|███████████████████████▏                          | 116/250 [11:09<13:00,  5.82s/it][A

Training loss: 0.71



 47%|███████████████████████▍                          | 117/250 [11:14<12:47,  5.77s/it][A

Training loss: 0.71



 47%|███████████████████████▌                          | 118/250 [11:20<12:37,  5.74s/it][A

Training loss: 0.67



 48%|███████████████████████▊                          | 119/250 [11:26<12:36,  5.78s/it][A

Training loss: 0.66



 48%|████████████████████████                          | 120/250 [11:32<12:30,  5.77s/it][A

Training loss: 0.71



 48%|████████████████████████▏                         | 121/250 [11:37<12:18,  5.72s/it][A

Training loss: 0.73



 49%|████████████████████████▍                         | 122/250 [11:43<12:16,  5.75s/it][A

Training loss: 0.70



 49%|████████████████████████▌                         | 123/250 [11:49<12:23,  5.85s/it][A

Training loss: 0.70



 50%|████████████████████████▊                         | 124/250 [11:55<12:03,  5.74s/it][A

Training loss: 0.65



 50%|█████████████████████████                         | 125/250 [12:00<11:57,  5.74s/it][A

Training loss: 0.69



 50%|█████████████████████████▏                        | 126/250 [12:06<11:49,  5.72s/it][A

Training loss: 0.71



 51%|█████████████████████████▍                        | 127/250 [12:12<11:47,  5.75s/it][A

Training loss: 0.63



 51%|█████████████████████████▌                        | 128/250 [12:17<11:33,  5.68s/it][A

Training loss: 0.69



 52%|█████████████████████████▊                        | 129/250 [12:23<11:34,  5.74s/it][A

Training loss: 0.70



 52%|██████████████████████████                        | 130/250 [12:29<11:19,  5.67s/it][A

Training loss: 0.66



 52%|██████████████████████████▏                       | 131/250 [12:34<11:10,  5.63s/it][A

Training loss: 0.72



 53%|██████████████████████████▍                       | 132/250 [12:40<11:04,  5.63s/it][A

Training loss: 0.65



 53%|██████████████████████████▌                       | 133/250 [12:45<10:55,  5.60s/it][A

Training loss: 0.74



 54%|██████████████████████████▊                       | 134/250 [12:51<10:52,  5.62s/it][A

Training loss: 0.68



 54%|███████████████████████████                       | 135/250 [12:57<10:50,  5.65s/it][A

Training loss: 0.74



 54%|███████████████████████████▏                      | 136/250 [13:03<10:54,  5.74s/it][A

Training loss: 0.77



 55%|███████████████████████████▍                      | 137/250 [13:08<10:45,  5.71s/it][A

Training loss: 0.73



 55%|███████████████████████████▌                      | 138/250 [13:14<10:32,  5.65s/it][A

Training loss: 0.67



 56%|███████████████████████████▊                      | 139/250 [13:20<10:27,  5.65s/it][A

Training loss: 0.68



 56%|████████████████████████████                      | 140/250 [13:25<10:27,  5.70s/it][A

Training loss: 0.79



 56%|████████████████████████████▏                     | 141/250 [13:31<10:26,  5.75s/it][A

Training loss: 0.69



 57%|████████████████████████████▍                     | 142/250 [13:37<10:21,  5.76s/it][A

Training loss: 0.69



 57%|████████████████████████████▌                     | 143/250 [13:43<10:18,  5.78s/it][A

Training loss: 0.70



 58%|████████████████████████████▊                     | 144/250 [13:48<10:02,  5.68s/it][A

Training loss: 0.65



 58%|████████████████████████████▉                     | 145/250 [13:54<09:55,  5.67s/it][A

Training loss: 0.67



 58%|█████████████████████████████▏                    | 146/250 [14:00<09:49,  5.67s/it][A

Training loss: 0.70



 59%|█████████████████████████████▍                    | 147/250 [14:06<09:52,  5.76s/it][A

Training loss: 0.72



 59%|█████████████████████████████▌                    | 148/250 [14:11<09:51,  5.80s/it][A

Training loss: 0.72



 60%|█████████████████████████████▊                    | 149/250 [14:18<10:00,  5.94s/it][A

Training loss: 0.67



 60%|██████████████████████████████                    | 150/250 [14:24<09:59,  5.99s/it][A

Training loss: 0.71



 60%|██████████████████████████████▏                   | 151/250 [14:30<09:47,  5.93s/it][A

Training loss: 0.71



 61%|██████████████████████████████▍                   | 152/250 [14:36<09:38,  5.91s/it][A

Training loss: 0.70



 61%|██████████████████████████████▌                   | 153/250 [14:42<09:37,  5.96s/it][A

Training loss: 0.72



 62%|██████████████████████████████▊                   | 154/250 [14:47<09:30,  5.94s/it][A

Training loss: 0.69



 62%|███████████████████████████████                   | 155/250 [14:54<09:30,  6.01s/it][A

Training loss: 0.68



 62%|███████████████████████████████▏                  | 156/250 [15:00<09:28,  6.05s/it][A

Training loss: 0.67



 63%|███████████████████████████████▍                  | 157/250 [15:05<09:11,  5.93s/it][A

Training loss: 0.71



 63%|███████████████████████████████▌                  | 158/250 [15:11<09:00,  5.88s/it][A

Training loss: 0.73



 64%|███████████████████████████████▊                  | 159/250 [15:17<08:50,  5.83s/it][A

Training loss: 0.72



 64%|████████████████████████████████                  | 160/250 [15:23<08:40,  5.78s/it][A

Training loss: 0.71



 64%|████████████████████████████████▏                 | 161/250 [15:29<08:38,  5.82s/it][A

Training loss: 0.66



 65%|████████████████████████████████▍                 | 162/250 [15:34<08:34,  5.85s/it][A

Training loss: 0.67



 65%|████████████████████████████████▌                 | 163/250 [15:40<08:33,  5.90s/it][A

Training loss: 0.70



 66%|████████████████████████████████▊                 | 164/250 [15:46<08:24,  5.86s/it][A

Training loss: 0.68



 66%|█████████████████████████████████                 | 165/250 [15:52<08:10,  5.77s/it][A

Training loss: 0.70



 66%|█████████████████████████████████▏                | 166/250 [15:57<07:59,  5.70s/it][A

Training loss: 0.71



 67%|█████████████████████████████████▍                | 167/250 [16:03<07:53,  5.70s/it][A

Training loss: 0.68



 67%|█████████████████████████████████▌                | 168/250 [16:09<07:44,  5.67s/it][A

Training loss: 0.71



 68%|█████████████████████████████████▊                | 169/250 [16:15<07:46,  5.75s/it][A

Training loss: 0.68



 68%|██████████████████████████████████                | 170/250 [16:20<07:36,  5.71s/it][A

Training loss: 0.65



 68%|██████████████████████████████████▏               | 171/250 [16:26<07:32,  5.73s/it][A

Training loss: 0.70



 69%|██████████████████████████████████▍               | 172/250 [16:32<07:24,  5.70s/it][A

Training loss: 0.70



 69%|██████████████████████████████████▌               | 173/250 [16:37<07:20,  5.72s/it][A

Training loss: 0.67



 70%|██████████████████████████████████▊               | 174/250 [16:43<07:13,  5.70s/it][A

Training loss: 0.70



 70%|███████████████████████████████████               | 175/250 [16:49<07:05,  5.67s/it][A

Training loss: 0.70



 70%|███████████████████████████████████▏              | 176/250 [16:54<06:55,  5.62s/it][A

Training loss: 0.71



 71%|███████████████████████████████████▍              | 177/250 [17:00<06:53,  5.67s/it][A

Training loss: 0.67



 71%|███████████████████████████████████▌              | 178/250 [17:06<06:50,  5.70s/it][A

Training loss: 0.70



 72%|███████████████████████████████████▊              | 179/250 [17:11<06:46,  5.73s/it][A

Training loss: 0.70



 72%|████████████████████████████████████              | 180/250 [17:17<06:38,  5.69s/it][A

Training loss: 0.68



 72%|████████████████████████████████████▏             | 181/250 [17:23<06:35,  5.73s/it][A

Training loss: 0.73



 73%|████████████████████████████████████▍             | 182/250 [17:29<06:30,  5.75s/it][A

Training loss: 0.67



 73%|████████████████████████████████████▌             | 183/250 [17:34<06:25,  5.76s/it][A

Training loss: 0.70



 74%|████████████████████████████████████▊             | 184/250 [17:40<06:22,  5.79s/it][A

Training loss: 0.71



 74%|█████████████████████████████████████             | 185/250 [17:46<06:18,  5.82s/it][A

Training loss: 0.70



 74%|█████████████████████████████████████▏            | 186/250 [17:52<06:10,  5.79s/it][A

Training loss: 0.69



 75%|█████████████████████████████████████▍            | 187/250 [17:58<06:04,  5.78s/it][A

Training loss: 0.70



 75%|█████████████████████████████████████▌            | 188/250 [18:03<05:54,  5.71s/it][A

Training loss: 0.68



 76%|█████████████████████████████████████▊            | 189/250 [18:09<05:48,  5.72s/it][A

Training loss: 0.71



 76%|██████████████████████████████████████            | 190/250 [18:15<05:41,  5.69s/it][A

Training loss: 0.69



 76%|██████████████████████████████████████▏           | 191/250 [18:20<05:37,  5.72s/it][A

Training loss: 0.71



 77%|██████████████████████████████████████▍           | 192/250 [18:26<05:30,  5.70s/it][A

Training loss: 0.70



 77%|██████████████████████████████████████▌           | 193/250 [18:32<05:26,  5.73s/it][A

Training loss: 0.72



 78%|██████████████████████████████████████▊           | 194/250 [18:38<05:23,  5.77s/it][A

Training loss: 0.69



 78%|███████████████████████████████████████           | 195/250 [18:43<05:14,  5.72s/it][A

Training loss: 0.75



 78%|███████████████████████████████████████▏          | 196/250 [18:49<05:07,  5.69s/it][A

Training loss: 0.69



 79%|███████████████████████████████████████▍          | 197/250 [18:55<05:02,  5.71s/it][A

Training loss: 0.67



 79%|███████████████████████████████████████▌          | 198/250 [19:00<04:55,  5.68s/it][A

Training loss: 0.70



 80%|███████████████████████████████████████▊          | 199/250 [19:06<04:52,  5.73s/it][A

Training loss: 0.70



 80%|████████████████████████████████████████          | 200/250 [19:12<04:43,  5.67s/it][A

Training loss: 0.70



 80%|████████████████████████████████████████▏         | 201/250 [19:17<04:38,  5.68s/it][A

Training loss: 0.68



 81%|████████████████████████████████████████▍         | 202/250 [19:23<04:33,  5.70s/it][A

Training loss: 0.70



 81%|████████████████████████████████████████▌         | 203/250 [19:29<04:33,  5.81s/it][A

Training loss: 0.69



 82%|████████████████████████████████████████▊         | 204/250 [19:35<04:26,  5.79s/it][A

Training loss: 0.70



 82%|█████████████████████████████████████████         | 205/250 [19:41<04:25,  5.89s/it][A

Training loss: 0.70



 82%|█████████████████████████████████████████▏        | 206/250 [19:47<04:19,  5.91s/it][A

Training loss: 0.70



 83%|█████████████████████████████████████████▍        | 207/250 [19:53<04:11,  5.84s/it][A

Training loss: 0.70



 83%|█████████████████████████████████████████▌        | 208/250 [19:58<04:01,  5.74s/it][A

Training loss: 0.69



 84%|█████████████████████████████████████████▊        | 209/250 [20:04<03:54,  5.71s/it][A

Training loss: 0.68



 84%|██████████████████████████████████████████        | 210/250 [20:10<03:49,  5.73s/it][A

Training loss: 0.72



 84%|██████████████████████████████████████████▏       | 211/250 [20:15<03:42,  5.71s/it][A

Training loss: 0.66



 85%|██████████████████████████████████████████▍       | 212/250 [20:21<03:38,  5.75s/it][A

Training loss: 0.73



 85%|██████████████████████████████████████████▌       | 213/250 [20:27<03:36,  5.84s/it][A

Training loss: 0.68



 86%|██████████████████████████████████████████▊       | 214/250 [20:33<03:29,  5.81s/it][A

Training loss: 0.67



 86%|███████████████████████████████████████████       | 215/250 [20:39<03:22,  5.78s/it][A

Training loss: 0.69



 86%|███████████████████████████████████████████▏      | 216/250 [20:44<03:14,  5.72s/it][A

Training loss: 0.72



 87%|███████████████████████████████████████████▍      | 217/250 [20:50<03:06,  5.66s/it][A

Training loss: 0.70



 87%|███████████████████████████████████████████▌      | 218/250 [20:55<03:01,  5.68s/it][A

Training loss: 0.68



 88%|███████████████████████████████████████████▊      | 219/250 [21:01<02:55,  5.67s/it][A

Training loss: 0.70



 88%|████████████████████████████████████████████      | 220/250 [21:07<02:47,  5.60s/it][A

Training loss: 0.69



 88%|████████████████████████████████████████████▏     | 221/250 [21:12<02:43,  5.65s/it][A

Training loss: 0.70



 89%|████████████████████████████████████████████▍     | 222/250 [21:18<02:39,  5.68s/it][A

Training loss: 0.70



 89%|████████████████████████████████████████████▌     | 223/250 [21:24<02:33,  5.69s/it][A

Training loss: 0.71



 90%|████████████████████████████████████████████▊     | 224/250 [21:30<02:28,  5.71s/it][A

Training loss: 0.67



 90%|█████████████████████████████████████████████     | 225/250 [21:35<02:24,  5.77s/it][A

Training loss: 0.69



 90%|█████████████████████████████████████████████▏    | 226/250 [21:41<02:18,  5.77s/it][A

Training loss: 0.67



 91%|█████████████████████████████████████████████▍    | 227/250 [21:47<02:13,  5.79s/it][A

Training loss: 0.68



 91%|█████████████████████████████████████████████▌    | 228/250 [21:53<02:06,  5.74s/it][A

Training loss: 0.70



 92%|█████████████████████████████████████████████▊    | 229/250 [21:58<02:00,  5.73s/it][A

Training loss: 0.70



 92%|██████████████████████████████████████████████    | 230/250 [22:04<01:54,  5.70s/it][A

Training loss: 0.67



 92%|██████████████████████████████████████████████▏   | 231/250 [22:10<01:47,  5.67s/it][A

Training loss: 0.70



 93%|██████████████████████████████████████████████▍   | 232/250 [22:15<01:42,  5.69s/it][A

Training loss: 0.70



 93%|██████████████████████████████████████████████▌   | 233/250 [22:21<01:36,  5.65s/it][A

Training loss: 0.70



 94%|██████████████████████████████████████████████▊   | 234/250 [22:26<01:30,  5.63s/it][A

Training loss: 0.68



 94%|███████████████████████████████████████████████   | 235/250 [22:32<01:24,  5.63s/it][A

Training loss: 0.67



 94%|███████████████████████████████████████████████▏  | 236/250 [22:38<01:18,  5.64s/it][A

Training loss: 0.69



 95%|███████████████████████████████████████████████▍  | 237/250 [22:44<01:13,  5.68s/it][A

Training loss: 0.72



 95%|███████████████████████████████████████████████▌  | 238/250 [22:49<01:07,  5.59s/it][A

Training loss: 0.71



 96%|███████████████████████████████████████████████▊  | 239/250 [22:55<01:01,  5.63s/it][A

Training loss: 0.68



 96%|████████████████████████████████████████████████  | 240/250 [23:00<00:55,  5.58s/it][A

Training loss: 0.70



 96%|████████████████████████████████████████████████▏ | 241/250 [23:06<00:50,  5.62s/it][A

Training loss: 0.69



 97%|████████████████████████████████████████████████▍ | 242/250 [23:11<00:44,  5.57s/it][A

Training loss: 0.70



 97%|████████████████████████████████████████████████▌ | 243/250 [23:17<00:39,  5.64s/it][A

Training loss: 0.67



 98%|████████████████████████████████████████████████▊ | 244/250 [23:23<00:34,  5.69s/it][A

Training loss: 0.70



 98%|█████████████████████████████████████████████████ | 245/250 [23:29<00:28,  5.72s/it][A

Training loss: 0.70



 98%|█████████████████████████████████████████████████▏| 246/250 [23:35<00:23,  5.76s/it][A

Training loss: 0.69



 99%|█████████████████████████████████████████████████▍| 247/250 [23:40<00:17,  5.71s/it][A

Training loss: 0.71



 99%|█████████████████████████████████████████████████▌| 248/250 [23:46<00:11,  5.66s/it][A

Training loss: 0.72



100%|█████████████████████████████████████████████████▊| 249/250 [23:51<00:05,  5.67s/it][A

Training loss: 0.71



100%|██████████████████████████████████████████████████| 250/250 [23:57<00:00,  5.67s/it][A

In [14]:
model.eval()

# test on one batch
batch = next(iter(train_dataloader))

labels = batch["labels"]
# del batch["labels"]

batch = {k: torch.transpose(torch.stack(default_convert(v)), 0, 1) for k, v in batch.items()}
batch = {k: v.to(mps_device) for k, v in batch.items()}

outputs = model(**batch)
print(f"outputs = {outputs}")

for index, item in enumerate(output.logits):
    softmax = torch.nn.Softmax()
    print(f"outputs = {softmax(item)}")
    print(f"label = {labels[index]}")

# labels.to(mps_device)
# mps_labels = torch.as_tensor(labels, device=mps_device)

# loss = cross_entropy_loss(output.logits, mps_labels)
# loss.backward()

outputs = SequenceClassifierOutput(loss=None, logits=tensor([[0.1740, 0.1375],
        [0.1741, 0.1376],
        [0.1743, 0.1378],
        [0.1740, 0.1377],
        [0.1741, 0.1377],
        [0.1743, 0.1377],
        [0.1741, 0.1377],
        [0.1743, 0.1379]], device='mps:0', grad_fn=<MpsLinearBackward0>), hidden_states=None, attentions=None)
outputs = tensor([0.4522, 0.5478], device='mps:0', grad_fn=<SoftmaxBackward0>)
label = 0
outputs = tensor([0.4769, 0.5231], device='mps:0', grad_fn=<SoftmaxBackward0>)
label = 1
outputs = tensor([0.3628, 0.6372], device='mps:0', grad_fn=<SoftmaxBackward0>)
label = 0
outputs = tensor([0.4993, 0.5007], device='mps:0', grad_fn=<SoftmaxBackward0>)
label = 0
outputs = tensor([0.5464, 0.4536], device='mps:0', grad_fn=<SoftmaxBackward0>)
label = 0
outputs = tensor([0.4529, 0.5471], device='mps:0', grad_fn=<SoftmaxBackward0>)
label = 1
outputs = tensor([0.5302, 0.4698], device='mps:0', grad_fn=<SoftmaxBackward0>)
label = 1
outputs = tensor([0.4415, 0.558

  print(f"outputs = {softmax(item)}")


In [12]:
from datasets import load_metric


metric = load_metric('accuracy')

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

In [14]:
from datasets import load_metric


validation_dataloader = DataLoader(ds['validation'], shuffle=True, batch_size=8)

metric = load_metric("accuracy")
model.eval()
for batch in validation_dataloader:
    labels = batch["labels"]
    mps_labels = torch.as_tensor(labels, device=mps_device)
    del batch["labels"]

    batch = {k: torch.transpose(torch.stack(default_convert(v)), 0, 1) for k, v in batch.items()}
    batch = {k: v.to(mps_device) for k, v in batch.items()}
    
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    softmax = torch.nn.Softmax()
    # for index, item in enumerate(logits):
        # print(f"probabilities = {softmax(item)}")
        # print(f"label = {labels[index]}")
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=mps_labels)

metric.compute()

{'accuracy': 0.5182012847965739}

In [21]:
from bertviz import head_view


inputs = tokenizer.encode("The cat sat on the mat", return_tensors='pt')
inputs = inputs.to(mps_device)
outputs = model(inputs)
attention = outputs[-1]  # Output includes attention weights when output_attentions=True
tokens = tokenizer.convert_ids_to_tokens(inputs[0])
head_view(attention, tokens)

<IPython.core.display.Javascript object>