### Сетап

In [1]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, get_scheduler, default_data_collator
from accelerate import Accelerator
from tqdm.auto import tqdm
import numpy as np
import evaluate
import collections
import ipywidgets as widgets

In [2]:
DATASET_NAME = "kuznetsoffandrey/sberquad"
MODEL_NAME = "DeepPavlov/rubert-base-cased"
MODEL_SAVE_DIR ='rubert-v3'

### Загрузка датасета и его предобработка

In [3]:
dataset = load_dataset(DATASET_NAME)

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 45328
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 5036
    })
    test: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 23936
    })
})

In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [6]:
def preprocess_data(examples, tokenizer, is_test=False, max_length=384, stride=128):
    questions = [q.strip() for q in examples['question']]
    
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    
    answers = examples["answers"]
    if is_test:
        offset_mapping = inputs["offset_mapping"]
        example_ids = []
    else:
        offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    start_positions = []
    end_positions = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        offset = offset_mapping[i]
        sequence_ids = inputs.sequence_ids(i)

        if is_test:
            example_ids.append(examples["id"][sample_idx])
            inputs["offset_mapping"][i] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]

        if len(answer['answer_start'])==0:
            start_positions.append(0)
            end_positions.append(0)
            continue
        
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)
            
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    if is_test:
        inputs["example_id"] = example_ids
    return inputs

In [7]:
dataset_train = dataset['train'].map(
    preprocess_data,
    batched=True,
    remove_columns=dataset['train'].column_names,
    fn_kwargs = {
        'tokenizer': tokenizer,
    }
)

In [8]:
dataset_train

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 45544
})

In [9]:
dataset_val = dataset['validation'].map(
    preprocess_data,
    batched=True,
    remove_columns=dataset['validation'].column_names,
    fn_kwargs = {
        'tokenizer': tokenizer,
        'is_test': True,
    }
)

Map:   0%|          | 0/5036 [00:00<?, ? examples/s]

In [10]:
dataset_train.set_format("torch")

dataset_val_formatted = dataset_val.remove_columns(["example_id", "offset_mapping"])
dataset_val_formatted.set_format("torch")

In [11]:
dataloader_train = DataLoader(
    dataset_train,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=16
)

dataloader_val = DataLoader(
    dataset_val_formatted,
    collate_fn=default_data_collator,
    batch_size=16
)

### Определение функций для оценки тренировки модели

In [12]:
def format_predictions(start_logits, end_logits, inputs, examples, n_best=20, max_answer_length=30):
    assert n_best <= len(inputs['offset_mapping'][0]), 'n_best cannot be larger than max_length'
    
    example_to_inputs = collections.defaultdict(list)
    for idx, feature in enumerate(inputs):
        example_to_inputs[str(feature["example_id"])].append(idx)
    
    predicted_answers = []
    for example in tqdm(examples):
        example_id = str(example["id"])
        context = example["context"]
        answers = []
        
        for feature_index in example_to_inputs[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]

            offsets = inputs[feature_index]['offset_mapping']
            start_indices = np.argsort(start_logit)[-1:-n_best-1:-1].tolist()
            end_indices = np.argsort(end_logit)[-1 :-n_best-1: -1].tolist()

            for start_index in start_indices:
                for end_index in end_indices:
                    if (end_index < start_index or end_index - start_index + 1 > max_answer_length):
                        continue
                    if (offsets[start_index] is None)^(offsets[end_index] is None):
                        continue
                    
                    if (offsets[start_index] is None)&(offsets[end_index] is None):
                        answers.append(
                            {
                                "text": '',
                                "logit_score": start_logit[start_index] + end_logit[end_index],
                            }
                        )

                    else:
                        answers.append(
                            {
                                "text": context[offsets[start_index][0] : offsets[end_index][1]],
                                "logit_score": start_logit[start_index] + end_logit[end_index],
                            }
                        )

        if len(answers) > 0:
            best_answer = max(answers, key=lambda x:x['logit_score'])
            predicted_answers.append({'id':example_id, 'prediction_text':best_answer['text']})
        else:
            predicted_answers.append({'id':example_id, 'prediction_text':''})

    return predicted_answers

In [13]:
def compute_metrics(start_logits, end_logits, inputs, examples, n_best = 20, max_answer_length=30):

    metric = evaluate.load('squad_v2')
    predicted_answers = format_predictions(start_logits, end_logits, inputs, examples,
                                           n_best=n_best, max_answer_length=max_answer_length)
    for pred in predicted_answers:
        pred['no_answer_probability'] = 1.0 if pred['prediction_text'] == '' else 0.0

    correct_answers = []
    for example in examples:
        input_id = str(example["id"])
        answers = example["answers"]
        correct_answers.append({
            "id": input_id,
            "answers": {
                "text": answers["text"],
                "answer_start": answers["answer_start"]
            }
        })
    return metric.compute(predictions=predicted_answers, references=correct_answers)

### Тренировка модели

In [14]:
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
optimizer = AdamW(model.parameters(), lr=3e-5)

num_epochs = 1
num_training_steps = len(dataloader_train)*num_epochs

scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps = 0,
    num_training_steps = num_training_steps,
)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at DeepPavlov/rubert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
accelerator = Accelerator(mixed_precision="fp16")

model, optimizer, dataloader_train, dataloader_val = accelerator.prepare(
    model, optimizer, dataloader_train, dataloader_val
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [16]:
progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_epochs):
    # Тренировка модели
    model.train()
    for step, batch in enumerate(dataloader_train):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
    
    # Оценка на валидационных данных
    model.eval()
    start_logits = []
    end_logits = []
    for batch in tqdm(dataloader_val, desc='Оценка модели'):
        with torch.no_grad():
            outputs = model(**batch)
        start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())
        end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())
    
    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)
    start_logits = start_logits[: len(dataset_val)]
    end_logits = end_logits[: len(dataset_val)]
    
    metrics = compute_metrics(
        start_logits, end_logits, dataset_val, dataset['validation']
    )
    print(f"Эпоха {epoch+1}:", metrics)
    
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(MODEL_SAVE_DIR,save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(MODEL_SAVE_DIR)

  0%|          | 0/2847 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Оценка модели:   0%|          | 0/317 [00:00<?, ?it/s]

  0%|          | 0/5036 [00:00<?, ?it/s]

Эпоха 1: {'exact': 62.4900714853058, 'f1': 82.10095375768547, 'total': 5036, 'HasAns_exact': 62.4900714853058, 'HasAns_f1': 82.10095375768547, 'HasAns_total': 5036, 'best_exact': 62.4900714853058, 'best_exact_thresh': 0.0, 'best_f1': 82.10095375768547, 'best_f1_thresh': 0.0}


Лучшие метрики были получены на 1ой эпохе:

> Exact match = 62.49 \
> F1-score = 82.10

### Сохранение модели

In [18]:
tokenizer.save_pretrained(MODEL_SAVE_DIR)
model.save_pretrained(MODEL_SAVE_DIR)