# ModernBERT fine tuning
---

In [1]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import torch
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = '../models/answerdotai--ModernBERT-base'


model = AutoModelForQuestionAnswering.from_pretrained(model_path, attn_implementation="flash_attention_2").to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)

Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in ModernBertForQuestionAnswering is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", dtype=torch.float16)`
Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in ModernBertModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", dtype=torch.float16)`


Loading weights:   0%|          | 0/136 [00:00<?, ?it/s]

[1mModernBertForQuestionAnswering LOAD REPORT[0m from: ../models/answerdotai--ModernBERT-base
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


# Dataset preprocessing, tokenization
---

In [None]:
import json
import orjson

path = '/disk-1/drezov/preprocessed_data/nq_original/train/train_nq.jsonl'
data = []
with open(path, 'r', encoding='utf-8') as f:
    stop_f = 100
    data = []
    for line in f:
        _ = orjson.loads(line)
        data.append(_)
        stop_f -= 1

In [2]:
from datasets import load_dataset

raw_datasets = load_dataset("squad")



In [3]:
context = raw_datasets["train"][0]["context"]
question = raw_datasets["train"][0]["question"]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS]To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?[SEP]Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.[SEP]'

In [4]:
inputs = tokenizer(
    raw_datasets["train"][2:6]["question"],
    raw_datasets["train"][2:6]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True
)

for ids in inputs["input_ids"]:
    print(tokenizer.decode(ids))

[CLS]The Basilica of the Sacred heart at Notre Dame is beside to which structure?[SEP]Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the[SEP]
[CLS]The Basilica of the Sacred heart at Notre Dame is beside to which structure?[SEP] of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary[SEP]
[CLS]The Basilica of the Sacred heart at Notre Dame is beside to which structure?[SEP] the Main Buil

In [5]:
answers = raw_datasets["train"][2:6]["answers"]
start_positions = []
end_positions = []

for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

start_positions, end_positions

([80, 49, 18, 0, 0, 60, 23, 0, 34, 0, 0, 0, 66, 33, 0, 0, 0, 0, 0],
 [82, 51, 20, 0, 0, 66, 29, 0, 40, 0, 0, 0, 67, 34, 0, 0, 0, 0, 0])

In [6]:
idx = 4
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start : end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

Theoretical answer: a Marian place of prayer and reflection, labels give: [CLS]


In [7]:
max_length = 512
stride = 128


def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [8]:
train_dataset = raw_datasets["train"].map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)
len(raw_datasets["train"]), len(train_dataset)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

(87599, 87749)

In [9]:
def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

In [10]:
validation_dataset = raw_datasets["validation"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)
len(raw_datasets["validation"]), len(validation_dataset)

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

(10570, 10619)

In [26]:
sub_set = raw_datasets['validation'].select(range(100))
sub_set.set_format('torch')

sub_set 

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 100
})

In [45]:
small_eval_set = raw_datasets["validation"].select(range(100))
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
eval_set = small_eval_set.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)

In [46]:
eval_set[0]

{'input_ids': [101,
  5979,
  4279,
  1264,
  2533,
  1103,
  10402,
  1120,
  3198,
  5308,
  1851,
  136,
  102,
  3198,
  5308,
  1851,
  1108,
  1126,
  1237,
  1709,
  1342,
  1106,
  4959,
  1103,
  3628,
  1104,
  1103,
  1305,
  2289,
  1453,
  113,
  4279,
  114,
  1111,
  1103,
  1410,
  1265,
  119,
  1109,
  1237,
  2289,
  3047,
  113,
  10402,
  114,
  3628,
  7068,
  14722,
  2378,
  1103,
  1305,
  2289,
  3047,
  113,
  24743,
  114,
  3628,
  2938,
  13598,
  1572,
  782,
  1275,
  1106,
  7379,
  1147,
  1503,
  3198,
  5308,
  1641,
  119,
  1109,
  1342,
  1108,
  1307,
  1113,
  1428,
  128,
  117,
  1446,
  117,
  1120,
  12388,
  112,
  188,
  3339,
  1107,
  1103,
  1727,
  2948,
  2410,
  3894,
  1120,
  3364,
  10200,
  117,
  1756,
  119,
  1249,
  1142,
  1108,
  1103,
  13163,
  3198,
  5308,
  117,
  1103,
  2074,
  13463,
  1103,
  107,
  5404,
  5453,
  107,
  1114,
  1672,
  2284,
  118,
  12005,
  11751,
  117,
  1112,
  1218,
  1112,
  7818,
  28117,

In [49]:
eval_set.column_names

['input_ids',
 'token_type_ids',
 'attention_mask',
 'offset_mapping',
 'example_id']

In [56]:
sub_eval_set = eval_set.select(range(100))
print(len(sub_eval_set[0]['input_ids']))
sub_eval_set.set_format('torch')
print(type(sub_eval_set['input_ids'][0]))

sub_eval_set[0]

512
<class 'torch.Tensor'>


{'input_ids': tensor([  101,  5979,  4279,  1264,  2533,  1103, 10402,  1120,  3198,  5308,
          1851,   136,   102,  3198,  5308,  1851,  1108,  1126,  1237,  1709,
          1342,  1106,  4959,  1103,  3628,  1104,  1103,  1305,  2289,  1453,
           113,  4279,   114,  1111,  1103,  1410,  1265,   119,  1109,  1237,
          2289,  3047,   113, 10402,   114,  3628,  7068, 14722,  2378,  1103,
          1305,  2289,  3047,   113, 24743,   114,  3628,  2938, 13598,  1572,
           782,  1275,  1106,  7379,  1147,  1503,  3198,  5308,  1641,   119,
          1109,  1342,  1108,  1307,  1113,  1428,   128,   117,  1446,   117,
          1120, 12388,   112,   188,  3339,  1107,  1103,  1727,  2948,  2410,
          3894,  1120,  3364, 10200,   117,  1756,   119,  1249,  1142,  1108,
          1103, 13163,  3198,  5308,   117,  1103,  2074, 13463,  1103,   107,
          5404,  5453,   107,  1114,  1672,  2284,   118, 12005, 11751,   117,
          1112,  1218,  1112,  7818, 28

In [59]:
import torch
from transformers import AutoModelForQuestionAnswering

eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
print(type(eval_set_for_model))
eval_set_for_model.set_format(type="torch")
print(eval_set_for_model.format)
print(eval_set_for_model.column_names)
print(type(eval_set_for_model['input_ids']))
print(type(eval_set_for_model[0]['input_ids']))
print(eval_set_for_model[0]['input_ids'].dim())
print(eval_set_for_model[0]['input_ids'])


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch = eval_set_for_model[:]
batch = {k: v.to(device) for k, v in batch.items()}
trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(
    device
)

with torch.no_grad():
    outputs = trained_model(**batch)

<class 'datasets.arrow_dataset.Dataset'>
{'type': 'torch', 'format_kwargs': {}, 'columns': ['input_ids', 'token_type_ids', 'attention_mask'], 'output_all_columns': False}
['input_ids', 'token_type_ids', 'attention_mask']
<class 'datasets.arrow_dataset.Column'>
<class 'torch.Tensor'>
1
tensor([  101,  5979,  4279,  1264,  2533,  1103, 10402,  1120,  3198,  5308,
         1851,   136,   102,  3198,  5308,  1851,  1108,  1126,  1237,  1709,
         1342,  1106,  4959,  1103,  3628,  1104,  1103,  1305,  2289,  1453,
          113,  4279,   114,  1111,  1103,  1410,  1265,   119,  1109,  1237,
         2289,  3047,   113, 10402,   114,  3628,  7068, 14722,  2378,  1103,
         1305,  2289,  3047,   113, 24743,   114,  3628,  2938, 13598,  1572,
          782,  1275,  1106,  7379,  1147,  1503,  3198,  5308,  1641,   119,
         1109,  1342,  1108,  1307,  1113,  1428,   128,   117,  1446,   117,
         1120, 12388,   112,   188,  3339,  1107,  1103,  1727,  2948,  2410,
         389

Loading weights:   0%|          | 0/102 [00:00<?, ?it/s]

In [3]:
from transformers import Trainer, TrainingArguments

In [None]:
config = TrainingArguments(
    
)