# Fine-tune BERT base model

In [1]:
import os
import sys
import json
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from enum import Enum
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer, BertModel, BatchEncoding

sys.path.append("../")
from lib.utils import get_current_date
from lib.utils.constants import Subtask, Track, PreprocessTextLevel, PoolingStrategy, DatasetType
from lib.utils.models import sequential_fully_connected
from lib.data.loading import load_train_dev_test_df, build_data_loader
from lib.data.tokenizer import get_tokenizer
from lib.models import get_model
from lib.training.loss import get_loss_fn
from lib.training.metric import get_metric
from lib.training.loops import training_loop, make_predictions

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Location of the shared experiment configuration, expressed relative to the
# current working directory.
CONFIG_FILE = os.path.relpath("../config.json")

# Decode the JSON explicitly as UTF-8 so the parsed config does not depend on
# the platform's default text encoding.
with open(CONFIG_FILE, encoding="utf-8") as f:
    CONFIG = json.load(f)

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {DEVICE}")

Using device: cpu


In [3]:
# Collect every loader argument from the config first, then load the three
# splits in a single call.
split_loading_kwargs = dict(
    task=Subtask(CONFIG["task"]),
    track=Track(CONFIG["track"]),
    data_dir="../data/original_data",
    label_column=CONFIG["data"]["label_column"],
    test_size=CONFIG["data"]["test_size"],
    preprocess_text_level=PreprocessTextLevel(
        CONFIG["data"]["preprocess_text_level"]
    ),
)
df_train, df_dev, df_test = load_train_dev_test_df(**split_loading_kwargs)

# Report the size of each split, one line per frame.
for split_label, split_frame in (
    ("df_train", df_train),
    ("df_dev", df_dev),
    ("df_test", df_test),
):
    print(f"{split_label}.shape: {split_frame.shape}")

Loading train data...
Train/dev split... (df_train.shape: (119757, 5))
Loading test data...
Cleaning texts with preprocess level `PreprocessTextLevel.LIGHT`...
df_train.shape: (95805, 5)
df_dev.shape: (23952, 5)
df_test.shape: (5000, 5)


In [4]:
# Build the tokenizer from the "tokenizer" section of the config, and keep the
# configured maximum sequence length in a variable for quick inspection.
tokenizer = get_tokenizer(**CONFIG["tokenizer"])
max_seq_len = CONFIG["data"]["max_len"]

In [5]:
# Display the maximum sequence length read from CONFIG["data"]["max_len"].
max_seq_len

128

## Which layer should we use for classification?

* Pooled output of the [CLS] token
* A single hidden layer (any index from 0 to max_number_of_layers)
* First 4 layers + concat
* First 4 layers + mean
* First 4 layers + max
* Last 4 layers + concat
* Last 4 layers + mean
* Last 4 layers + max
* All layers + concat

In [6]:
# Clean CUDA memory: ask PyTorch to release the cached GPU memory it is not
# currently using.
torch.cuda.empty_cache()

In [6]:
# Small shuffled loader over the first 10 training rows; it feeds the
# forward-pass check further down in the notebook.
smoke_test_loader_kwargs = {
    "max_len": CONFIG["data"]["max_len"],
    "batch_size": CONFIG["data"]["batch_size"],
    "label_column": CONFIG["data"]["label_column"],
    "shuffle": True,
}
train_dataloader = build_data_loader(
    df_train[:10], tokenizer, **smoke_test_loader_kwargs
)

In [17]:
# Load the pretrained encoder named in the config (`BertModel` is imported in
# the notebook's import cell). `return_dict=False` makes the forward pass
# return a plain tuple, and `output_hidden_states=True` adds the per-layer
# hidden states to that tuple.
bert = BertModel.from_pretrained(
    CONFIG["model_config"]["pretrained_model_name"],
    return_dict=False,
    output_hidden_states=True,
)
# Number of transformer blocks in the encoder.
bert_num_layers = len(bert.encoder.layer)

In [None]:
# List the name of every parameter in the encoder, one per line.
for param_name, _param in bert.named_parameters():
    print(param_name)

In [18]:
# Display the module tree of the loaded encoder.
bert

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [27]:
class BERTWithLayerSelection(nn.Module):
    def __init__(
        self,
        pretrained_model_name: str,
        out_size: int = 1,
        dropout_p: float = 0.5,
        selected_layers: [int] = [-1],
        fc: [int] = [],
        out_activation: str | None = None,
    ):
        super().__init__()
        self.selected_layers = selected_layers

        self.bert = BertModel.from_pretrained(
            pretrained_model_name, return_dict=False, output_hidden_states=True
        )
        # self.drop_bert = nn.Dropout(dropout_p)

        input_size = len(selected_layers) * self.bert.config.hidden_size
        self.out = sequential_fully_connected(input_size, out_size, fc, dropout_p)

        self.out_activation = None
        if out_activation == "sigmoid":
            self.out_activation = nn.Sigmoid()

        self.freeze_transformer_layer()

    def forward(self, input_ids, attention_mask):
        bert_outputs = self.bert(input_ids, attention_mask)
        hidden_states = bert_outputs[2]
        bert_cls_features = torch.cat(
            [hidden_states[i][:, 0, :] for i in self.selected_layers],
            dim=1,
        )
        # output = self.drop_bert(pooled_output)
        output = self.out(bert_cls_features)

        if self.out_activation is not None:
            output = self.out_activation(output)

        return output

    def freeze_transformer_layer(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_transformer_layer(self):
        # BERT used only for feature extraction
        pass

    def get_predictions_from_outputs(self, outputs):
        if self.out_activation is None:
            return outputs.flatten().tolist()
        else:
            return torch.round(outputs).flatten().tolist()


In [28]:
# Use the last four hidden states, record that choice in the shared config,
# and build the model from the config before moving it to the target device.
bert_layers_to_use = [-1, -2, -3, -4]
CONFIG["model_config"].update(selected_layers=bert_layers_to_use)
bert_with_layer_selection = BERTWithLayerSelection(**CONFIG["model_config"])
bert_with_layer_selection = bert_with_layer_selection.to(DEVICE)

In [29]:
# Bert with layer selection: forward-pass check on the small training loader.

# Nothing is trained here, so skip autograd bookkeeping.
with torch.no_grad():
    for batch in train_dataloader:
        # The model was moved to DEVICE, so its inputs must be moved there
        # too; otherwise the forward pass fails whenever DEVICE is "cuda".
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        targets = batch["target"].to(DEVICE)

        outputs = bert_with_layer_selection(
            input_ids=input_ids, attention_mask=attention_mask
        )

        predictions = bert_with_layer_selection.get_predictions_from_outputs(outputs)
        true = targets.flatten().tolist()

        print(f"predictions = {predictions}")
        print(f"true = {true}")


predictions = [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]
true = [0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0]
predictions = [1.0, 1.0]
true = [0.0, 1.0]


In [8]:
# Layer subsets to compare. The indices address the encoder's hidden-state
# tuple, where index 0 is the embedding output, so 1..4 are the first four
# encoder layers and -1..-4 the last four.
# A cumulative sweep can be added with, e.g.:
#   {f"1-{k}_layers": list(range(1, k + 1)) for k in range(1, 13)}
selected_layers_options = {
    "first_4_layers": [1, 2, 3, 4],
    "last_4_layers": [-1, -2, -3, -4],
}

# Number of rows taken from each split; this cell runs on a small subset only.
NUM_ROWS_PER_SPLIT = 100

# Loader arguments shared by the train, dev and test loaders.
loader_kwargs = {
    "max_len": CONFIG["data"]["max_len"],
    "batch_size": CONFIG["data"]["batch_size"],
    "label_column": CONFIG["data"]["label_column"],
}

for selection_name, selected_layers in selected_layers_options.items():
    print(f"##### Layer selection: {selection_name} #####\n")

    # The model factory and the saved config both read the selection from here.
    CONFIG["model_config"]["selected_layers"] = selected_layers

    # Run directory: <date>-<task>[-<track>]-<model>_<selection>.
    run_name_parts = [get_current_date(), CONFIG["task"]]
    if CONFIG["track"] is not None:
        run_name_parts.append(CONFIG["track"])
    run_name_parts.append(f"{CONFIG['model']}_{selection_name}")
    results_path = "../runs/" + "-".join(str(part) for part in run_name_parts)

    print(f"Will save results to: {results_path}\n")
    # Create missing parents (e.g. a fresh checkout without ../runs) and do
    # not fail if the directory is already there.
    os.makedirs(results_path, exist_ok=True)

    # Store the exact configuration next to the results of this run.
    with open(os.path.join(results_path, "config.json"), "w", encoding="utf-8") as f:
        json.dump(CONFIG, f, indent=4)

    train_dataloader = build_data_loader(
        df_train[:NUM_ROWS_PER_SPLIT],
        tokenizer,
        shuffle=True,
        **loader_kwargs,
    )
    dev_dataloader = build_data_loader(
        df_dev[:NUM_ROWS_PER_SPLIT],
        tokenizer,
        **loader_kwargs,
    )
    test_dataloader = build_data_loader(
        df_test[:NUM_ROWS_PER_SPLIT],
        tokenizer,
        # Targets are flagged as present only when a test size is configured.
        has_targets=CONFIG["data"]["test_size"] is not None,
        **loader_kwargs,
    )

    num_epochs = CONFIG["training"]["num_epochs"]
    model = get_model(CONFIG["model"], CONFIG["model_config"]).to(DEVICE)
    loss_fn = get_loss_fn(CONFIG["training"]["loss"], DEVICE)
    optimizer_config = CONFIG["training"]["optimizer"]
    scheduler_config = CONFIG["training"]["scheduler"]
    metric_fn, is_better_metric_fn = get_metric(CONFIG["training"]["metric"])
    num_epochs_before_finetune = CONFIG["training"]["num_epochs_before_finetune"]

    best_model = training_loop(
        model,
        num_epochs,
        train_dataloader,
        dev_dataloader,
        loss_fn,
        optimizer_config,
        scheduler_config,
        DEVICE,
        metric_fn,
        is_better_metric_fn,
        results_path,
        num_epochs_before_finetune,
    )

    make_predictions(
        best_model,
        test_dataloader,
        DEVICE,
        results_path,
        label_column=CONFIG["data"]["label_column"],
        file_format=CONFIG["submission_format"],
    )

    print("-" * 50)
    print()

##### Layer selection: first_4_layers #####

Will save results to: ../runs/28-11-2023_09:34:44-SubtaskA-monolingual-bert_with_layer_selection_first_4_layers

Epoch 1/1
Freeze transformeer
--------------------
Batch=[1/13]; Loss=[0.64347]; Acc. Metric=0.875
Train Loss: 0.69481; Train Metric: 0.53000


100%|██████████| 13/13 [00:03<00:00,  4.27it/s]


Validation Loss: 0.68766; Validation Metric: 0.58000


100%|██████████| 13/13 [00:02<00:00,  4.60it/s]


--------------------------------------------------

##### Layer selection: last_4_layers #####

Will save results to: ../runs/28-11-2023_09:34:57-SubtaskA-monolingual-bert_with_layer_selection_last_4_layers

Epoch 1/1
Freeze transformeer
--------------------
Batch=[1/13]; Loss=[0.72704]; Acc. Metric=0.375
Train Loss: 0.71874; Train Metric: 0.42000


100%|██████████| 13/13 [00:03<00:00,  4.33it/s]


Validation Loss: 0.70619; Validation Metric: 0.44000


100%|██████████| 13/13 [00:02<00:00,  4.55it/s]

--------------------------------------------------






In [9]:
# Score the `first_4_layers` run with the project's scoring script.
# NOTE(review): the timestamped results directory is hard-coded from the
# training output above, so it must be updated after every re-run.
!python ../scores_and_plots.py --results-dir "../runs/28-11-2023_09:34:44-SubtaskA-monolingual-bert_with_layer_selection_first_4_layers"

Results on validation
Accuracy: 58.00%
  _warn_prf(average, modifier, msg_start, len(result))
Precision: 0.00%
Recall: 0.00%
F1: 0.00%
--------------------
Results on test
Accuracy: 0.00%
  _warn_prf(average, modifier, msg_start, len(result))
Precision: 0.00%
Recall: 0.00%
F1: 0.00%
--------------------


In [10]:
# Score the `last_4_layers` run with the project's scoring script.
# NOTE(review): the timestamped results directory is hard-coded from the
# training output above, so it must be updated after every re-run.
!python ../scores_and_plots.py --results-dir "../runs/28-11-2023_09:34:57-SubtaskA-monolingual-bert_with_layer_selection_last_4_layers"

Results on validation
Accuracy: 44.00%
Precision: 37.50%
Recall: 50.00%
F1: 42.86%
--------------------
Results on test
Accuracy: 76.00%
Precision: 100.00%
Recall: 76.00%
F1: 86.36%
--------------------
