# Practical semi-automatic classification with pre-trained BERT

In this notebook I download a pre-trained BERT model from hugging face, manually customize it and fine-tune it.

I create a custom classificator model in pure PyTorch and fine-tune it using PyTorch low-level primitives only.

Another notebook is available, using high-level HuggingFace classes.

In [7]:
import transformers
import torch
from torch import nn


class BERTClassification(nn.Module):
    def __init__ (self):
        # for reference on bert outputs for classification see:
        # https://huggingface.co/transformers/v3.0.2/model_doc/bert.html#bertmodel
        # https://stackoverflow.com/questions/61331991/bert-pooled-output-is-different-from-first-vector-of-sequence-output
        super(BERTClassification, self).__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-cased')
        self.bert_dropout = nn.Dropout(p=0.4)
        self.classifier = nn.Linear(768, 1)
        
    def forward(self, input_ids, attention_mask, token_type_ids):
        sequence_output, pooled_output = self.bert(
            input_ids=input_ids, 
            attention_mask=attention_mask,
            token_type_ids=token_type_ids, 
            return_dict=False
        )
        bert_with_dropout = self.bert_dropout(pooled_output)
        output = self.classifier(bert_with_dropout)
        
        return output
    
model = BERTClassification()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
import pandas as pd


essays = pd.read_csv("./data/essays.csv")

essays.loc[essays['cEXT'] == 'n', 'cEXT'] = 0
essays.loc[essays['cEXT'] == 'y', 'cEXT'] = 1

essays.loc[essays['cNEU'] == 'n', 'cNEU'] = 0
essays.loc[essays['cNEU'] == 'y', 'cNEU'] = 1

essays.loc[essays['cAGR'] == 'n', 'cAGR'] = 0
essays.loc[essays['cAGR'] == 'y', 'cAGR'] = 1

essays.loc[essays['cCON'] == 'n', 'cCON'] = 0
essays.loc[essays['cCON'] == 'y', 'cCON'] = 1

essays.loc[essays['cOPN'] == 'n', 'cOPN'] = 0
essays.loc[essays['cOPN'] == 'y', 'cOPN'] = 1

essays.astype({'cEXT': 'int32', 'cNEU': 'int32', 'cAGR': 'int32', 'cCON': 'int32', 'cOPN': 'int32'}).dtypes

essays

Unnamed: 0,#AUTHID,TEXT,cEXT,cNEU,cAGR,cCON,cOPN
0,1997_504851.txt,"Well, right now I just woke up from a mid-day ...",0,1,1,0,1
1,1997_605191.txt,"Well, here we go with the stream of consciousn...",0,0,1,0,0
2,1997_687252.txt,An open keyboard and buttons to push. The thin...,0,1,0,1,1
3,1997_568848.txt,I can't believe it! It's really happening! M...,1,0,1,1,0
4,1997_688160.txt,"Well, here I go with the good old stream of co...",1,0,1,0,1
...,...,...,...,...,...,...,...
2462,2004_493.txt,I'm home. wanted to go to bed but remembe...,0,1,0,1,0
2463,2004_494.txt,Stream of consiousnesssskdj. How do you s...,1,1,0,0,1
2464,2004_497.txt,"It is Wednesday, December 8th and a lot has be...",0,0,1,0,0
2465,2004_498.txt,"Man this week has been hellish. Anyways, now i...",0,1,0,0,1


In [5]:
# I'm running this on Apple Silicon. Activate Metal "mps" device, if available:
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

else:
    mps_device = torch.device("mps")


torch.device("mps")
model.to(mps_device)

BERTClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

In [6]:
from tqdm.auto import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split, default_convert
from datasets import Dataset
from transformers import get_scheduler
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", num_labels=2)

def tokenize_function(examples):
    return tokenizer(examples["TEXT"], padding="max_length", truncation=True)  # , return_tensors="pt")

essays_dataset = Dataset.from_pandas(essays)
tokenized_dataset = essays_dataset.map(tokenize_function, batched=True, batch_size=32)

tokenized_dataset = tokenized_dataset.rename_column("TEXT", "text")
tokenized_dataset = tokenized_dataset.rename_column("cNEU", "labels")
tokenized_dataset = tokenized_dataset.remove_columns(['#AUTHID', 'text', 'cEXT', 'cAGR', 'cCON', 'cOPN'])

# train_dataset = train_dataset.shuffle(seed=42).select(range(1000))
# test_dataset = test_dataset.shuffle(seed=42).select(range(1000))
train_dataset, test_dataset = random_split(tokenized_dataset, [2000, len(tokenized_dataset) - 2000])

train_dataloader = DataLoader(tokenized_dataset, shuffle=True, batch_size=32)

# parameters
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)

# optimizer, scheduler, loss, etc.
optimizer = AdamW(model.parameters(), lr=5e-5)
cross_entropy_loss = nn.CrossEntropyLoss()

lr_scheduler = get_scheduler(
    name="linear", 
    optimizer=optimizer, 
    num_warmup_steps=0, 
    num_training_steps=num_training_steps
)

progress_bar = tqdm(range(num_training_steps))

model.train()

for epoch in range(num_epochs):
    for batch in train_dataloader:
        print(batch.keys())
        labels = batch["labels"]
        del batch["labels"]
        batch = {k: torch.stack(default_convert(v)) for k, v in batch.items()}
        batch = {k: v.to(mps_device) for k, v in batch.items()}
        print(batch.keys())
        outputs = model(**batch)
        loss = cross_entropy_loss(outputs, labels)
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 74.53ba/s]
  0%|                                                                                                                                                                                                            | 0/234 [00:00<?, ?it/s]

dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])
dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])


ValueError: Expected input batch_size (512) to match target batch_size (32).