In [1]:
from datasets import load_from_disk, load_dataset
from transformers import BertModel, AutoTokenizer, DataCollatorWithPadding, BertForSequenceClassification, Trainer, EarlyStoppingCallback, TrainingArguments
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from datasets import load_metric
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score


In [2]:
sst3 = load_from_disk("../data/sst5")
sst3

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text', 'ternary_label'],
        num_rows: 8544
    })
    test: Dataset({
        features: ['text', 'label', 'label_text', 'ternary_label'],
        num_rows: 2210
    })
    validation: Dataset({
        features: ['text', 'label', 'label_text', 'ternary_label'],
        num_rows: 1101
    })
})

In [3]:
checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [4]:
max_sequence_length = 128
batch_size = 32
eval_steps = 200
learning_rate=2e-05
num_train_epochs=4
output_dir = "../output/"
model_dir = "../models/"
early_stopping_patience = 10

In [5]:
def tokenize_function(example):
    return tokenizer(example["text"],  truncation=True, padding="max_length", max_length=max_sequence_length )


tokenized_datasets = sst3.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Loading cached processed dataset at ../data/sst5/train/cache-3dccc6f510391e93.arrow


  0%|          | 0/3 [00:00<?, ?ba/s]

Loading cached processed dataset at ../data/sst5/validation/cache-9da11bc2f7bd4b20.arrow


In [6]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'label_text', 'ternary_label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 8544
    })
    test: Dataset({
        features: ['text', 'label', 'label_text', 'ternary_label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2210
    })
    validation: Dataset({
        features: ['text', 'label', 'label_text', 'ternary_label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1101
    })
})

In [7]:
tokenized_datasets = tokenized_datasets.remove_columns(["text", "label_text", "label"])
tokenized_datasets = tokenized_datasets.rename_column("ternary_label", "labels")
tokenized_datasets.set_format("torch")
tokenized_datasets["train"].column_names

['labels', 'input_ids', 'token_type_ids', 'attention_mask']

In [8]:
model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
model

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [9]:
def compute_metrics(p):    
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average="macro")
    precision = precision_score(y_true=labels, y_pred=pred, average="macro")
    f1 = f1_score(y_true=labels, y_pred=pred,average="macro")    
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

In [10]:
training_args = TrainingArguments(
   output_dir+"bert-base-cased-sst3",
   evaluation_strategy ='steps',
   eval_steps = eval_steps , # Evaluation and Save happens every eval_steps steps
   save_total_limit = 1, # Only last  model is saved. Older ones are deleted.
   learning_rate=learning_rate,
   per_device_train_batch_size=batch_size,
   per_device_eval_batch_size=batch_size,
   num_train_epochs=num_train_epochs,
   metric_for_best_model = 'f1',
   load_best_model_at_end=True)

In [11]:
trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)],
    compute_metrics=compute_metrics,
)

In [12]:
trainer.train()

Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Runtime,Samples Per Second
200,No log,0.727371,0.703906,0.620594,0.606937,0.575203,15.498,71.042
400,No log,0.707844,0.708447,0.645321,0.641031,0.637643,15.457,71.23
600,0.644800,0.747473,0.730245,0.663804,0.660455,0.656583,15.4705,71.167
800,0.644800,0.787229,0.710263,0.65161,0.648087,0.646314,15.4567,71.231
1000,0.339400,0.87026,0.731153,0.676442,0.674402,0.673999,15.464,71.198


TrainOutput(global_step=1068, training_loss=0.47937213704827125, metrics={'train_runtime': 1418.0168, 'train_samples_per_second': 0.753, 'total_flos': 2842898457526272, 'epoch': 4.0})

In [13]:
model.save_pretrained(model_dir+"bert-base-cased-sst3")

In [14]:
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [15]:
test_dataloader = DataLoader(
    tokenized_datasets["test"], batch_size=batch_size, collate_fn=data_collator
)

In [16]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [17]:
test_accuracy = load_metric("accuracy",)
test_f1 = load_metric("f1")
test_progress_bar = tqdm(range(len(test_dataloader)))

for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    labels = batch['labels']
    predictions = torch.argmax(logits, dim=-1)
    
    test_accuracy.add_batch(predictions=predictions, references=labels)
    test_f1.add_batch(predictions=predictions, references=labels)
    test_progress_bar.update(1)

print(f"Test Accuracy:{test_accuracy.compute()}.Test F1:{test_f1.compute(average='macro')}")

  0%|          | 0/70 [00:00<?, ?it/s]

Test Accuracy:{'accuracy': 0.7303167420814479}.Test F1:{'f1': 0.6608709697464324}
