In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os

# Parquet shards of the SberQuAD dataset on the Hugging Face Hub, keyed by split name.
splits = {'train': 'sberquad/train-00000-of-00001.parquet', 'validation': 'sberquad/validation-00000-of-00001.parquet', 'test': 'sberquad/test-00000-of-00001.parquet'}

df = pd.read_parquet("hf://datasets/kuznetsoffandrey/sberquad/" + splits["train"])

# Each `answers` cell carries parallel 'text' / 'answer_start' sequences; only the
# first answer is used. `answer_end` is the exclusive character offset of its end.
df['answer_start'] = df['answers'].apply(lambda x: x['answer_start'][0])
df['answer_end'] = df['answers'].apply(lambda x: len(x['text'][0])) + df['answer_start']

# Keep rows that have an answer position and only the columns used downstream.
# `assign` adds the label on a fresh frame, which avoids the SettingWithCopyWarning
# raised when a column is set on a column-sliced DataFrame.
df = (
    df.loc[df['answer_start'] != -1, ['context', 'question', 'answer_start', 'answer_end']]
    .reset_index(drop=True)
    .assign(has_answer=1)
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Generate examples for which no answer can be found
def generate_negative_examples(df, random_state=None):
    """Augment a frame of answerable QA pairs with unanswerable ones.

    Two kinds of negatives are appended:

    * mismatched pairs - every context is paired with a randomly shuffled
      question; pairs that also occur in ``df`` are discarded (a shuffled
      question may land on its own context, or on another row sharing that
      context) and at most ``len(df) // 3`` of the rest are kept;
    * context-free pairs - ``len(df) // 50`` questions with an empty context.

    Negatives get ``answer_start = answer_end = -1`` and ``has_answer = 0``.

    Args:
        df: DataFrame with at least the columns ``context`` and ``question``
            (plus the label columns that should survive the concatenation).
        random_state: optional integer seed. ``None`` (default) keeps the
            previous behaviour of drawing from NumPy's global random state.

    Returns:
        A new, row-shuffled DataFrame with a fresh ``RangeIndex`` containing
        the rows of ``df`` followed by the generated negatives. ``df`` itself
        is not modified.
    """
    # One RandomState shared by all draws so that a single seed fixes every shuffle
    # while the individual permutations still differ from each other.
    rng = None if random_state is None else np.random.RandomState(random_state)
    n_rows = len(df)

    # Pair by position, not by index label, so a non-default index cannot misalign rows.
    contexts = df['context'].reset_index(drop=True)
    known_pairs = set(zip(df['context'], df['question']))

    shuffled_questions = df['question'].sample(frac=1, random_state=rng).reset_index(drop=True)
    candidates = pd.DataFrame({'context': contexts, 'question': shuffled_questions})
    is_unseen_pair = np.asarray(
        [pair not in known_pairs for pair in zip(candidates['context'], candidates['question'])],
        dtype=bool,
    )
    shuffled_question_df = candidates.loc[is_unseen_pair][:n_rows // 3]

    # Also generate examples that have no context at all
    without_context = df['question'].sample(frac=1, random_state=rng).reset_index(drop=True)
    without_context_df = pd.DataFrame({
        'context': [''] * len(without_context),
        'question': without_context
    })[:n_rows // 50]

    negative_examples = pd.concat([shuffled_question_df, without_context_df], ignore_index=True)
    negative_examples['answer_start'] = -1
    negative_examples['answer_end'] = -1
    negative_examples['has_answer'] = 0

    combined = pd.concat([df, negative_examples], ignore_index=True)
    return combined.sample(frac=1, random_state=rng).reset_index(drop=True)

# Rebinds `df` to the augmented frame, so this cell is not idempotent: running it
# again without re-running the loading cell appends another batch of negatives.
df = generate_negative_examples(df)

In [None]:
class QAModel(nn.Module):
    """Extractive QA model: a transformer encoder with three linear heads.

    * ``start_vector`` / ``end_vector`` score every token as the start / end
      of the answer span;
    * ``classifier`` turns the first token's hidden state into the
      probability that the input contains an answer at all.
    """

    def __init__(self, transformer_model_name="DeepPavlov/rubert-base-cased"):
        """Load the pretrained encoder and create the three heads.

        Args:
            transformer_model_name: model id or path passed to
                ``AutoModel.from_pretrained``.
        """
        super().__init__()

        self.transformer = AutoModel.from_pretrained(transformer_model_name)

        hidden_size = self.transformer.config.hidden_size
        self.start_vector = nn.Linear(hidden_size, 1)
        self.end_vector = nn.Linear(hidden_size, 1)

        self.classifier = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and the three heads.

        Args:
            input_ids: token ids, forwarded unchanged to the encoder.
            attention_mask: attention mask, forwarded unchanged to the encoder.

        Returns:
            Tuple ``(start_logits, end_logits, has_answer)``: per-token start
            and end scores (one value per sequence position) and the
            per-example answer probability (already passed through a sigmoid).
        """
        transformer_output = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = transformer_output.last_hidden_state

        # The first token of the sequence is used for the has-answer classification
        has_answer = torch.sigmoid(self.classifier(hidden_states[:, 0, :])).squeeze(-1)

        start_logits = self.start_vector(hidden_states).squeeze(-1)

        start_pred = torch.argmax(start_logits, dim=-1)

        # Zero out the hidden states of tokens located before the predicted start.
        # The position index is created on the same device as the activations
        # (not on a notebook-level global), so the module also works inside
        # DataParallel replicas living on other GPUs.
        # NOTE(review): zeroed positions still receive the bias of `end_vector`
        # as their logit, so they are damped rather than strictly excluded.
        positions = torch.arange(hidden_states.size(1), device=hidden_states.device)
        mask = positions[None, :] >= start_pred[:, None]
        end_hidden_states = hidden_states * mask[:, :, None]

        end_logits = self.end_vector(end_hidden_states).squeeze(-1)

        return start_logits, end_logits, has_answer

model = QAModel().to(device)
# Let DataParallel pick up every visible GPU instead of hard-coding device_ids=[0, 1],
# which fails on a single-GPU machine. The wrapper is kept unconditionally so that
# state_dict keys keep their "module." prefix regardless of the hardware.
model = torch.nn.DataParallel(model).to(device)

In [None]:
tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")
# Encode (question, context) pairs, padded / truncated to 512 tokens, and keep
# the character offsets of every token so answers can be mapped to token indices.
X_train = tokenizer(df['question'].tolist(),
                    df['context'].tolist(),
                    return_tensors="pt",
                    return_offsets_mapping=True,
                    padding='max_length',
                    truncation=True,
                    max_length=512
                   )


def _find_answer_token_span(offsets, answer_start, answer_end):
    """Map a character span of the context to ``[start_token, end_token]``.

    Args:
        offsets: sequence of ``(char_start, char_end)`` pairs, one per token.
        answer_start: character offset where the answer begins.
        answer_end: character offset where the answer ends.

    Returns:
        A two-element list of token indices. Any border that was not found
        is filled with the index of the token preceding the one at which the
        scan stopped.
    """
    borders = []
    zero_offsets_seen = 0
    last_idx = 0
    for last_idx, (tok_start, tok_end) in enumerate(offsets):
        # Tokens with a (0, 0) offset are counted as segment delimiters: after the
        # second one the scan is inside the context, the third one ends the scan.
        if tok_start == 0 and tok_end == 0:
            zero_offsets_seen += 1
            if zero_offsets_seen > 2:
                break
        in_context = zero_offsets_seen == 2

        if tok_start <= answer_start <= tok_end:
            # Both bounds are inclusive, so the preceding token can match as well;
            # only the latest matching token is kept as the start.
            if borders:
                borders.pop()
            if in_context:
                borders.append(last_idx)

        if tok_start <= answer_end <= tok_end and in_context:
            borders.append(last_idx)
            break

    # NOTE(review): when the answer was cut off by truncation this fallback yields
    # a span at the end of the context while the example stays labelled as
    # answerable - confirm this is intended.
    while len(borders) < 2:
        borders.append(last_idx - 1)
    return borders


# Determine the start and end tokens of the answers.
# The columns are pulled out of the DataFrame once and offsets are converted to
# plain Python ints per row: doing Series lookups and tensor comparisons for every
# token made the original loop needlessly slow without changing the labels.
has_answer_column = df['has_answer'].tolist()
answer_start_column = df['answer_start'].tolist()
answer_end_column = df['answer_end'].tolist()

y_train = []
for i, offset in enumerate(X_train['offset_mapping']):
    if has_answer_column[i] == 0:
        # Unanswerable examples carry no span
        y_train.append([-1, -1])
        continue

    y_train.append(
        _find_answer_token_span(offset.tolist(), answer_start_column[i], answer_end_column[i])
    )

y_train = torch.tensor(y_train)

In [None]:
class CustomDataset(Dataset):
    """Dataset pairing tokenizer encodings with span and has-answer labels."""

    def __init__(self, encodings, labels, has_answer_labels):
        """Store the data; nothing is copied.

        Args:
            encodings: mapping from field name to an indexable holding one
                entry per example.
            labels: indexable of ``(start, end)`` pairs, one per example.
            has_answer_labels: indexable of per-example has-answer flags.
        """
        self.encodings = encodings
        self.labels = labels
        self.has_answer_labels = has_answer_labels

    def __getitem__(self, idx):
        """Return one example as a dict.

        The dict holds every encoding field for ``idx``, ``'labels'`` as a
        two-element list ``[start, end]`` of tensors, and ``'has_answer'``.
        """
        item = {key: val[idx] for key, val in self.encodings.items()}
        start_label, end_label = self.labels[idx][0], self.labels[idx][1]
        # as_tensor accepts plain numbers and existing tensors alike; wrapping an
        # existing tensor in torch.tensor() copy-constructs it and emits a
        # UserWarning for every single item.
        item['labels'] = [torch.as_tensor(start_label),
                          torch.as_tensor(end_label)]
        item['has_answer'] = torch.as_tensor(self.has_answer_labels[idx])
        return item

    def __len__(self):
        """Number of examples, taken from the label container."""
        return len(self.labels)

# Wrap the encodings and labels into a dataset and serve it in shuffled batches of 16.
dataset = CustomDataset(
    encodings=dict(X_train),
    labels=y_train,
    has_answer_labels=df['has_answer'].tolist(),
)
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=16)

In [None]:
save_dir = 'model_checkpoints'
os.makedirs(save_dir, exist_ok=True)

num_epochs = 10
checkpoint_every = 5  # a checkpoint is written after every `checkpoint_every`-th epoch
optimizer = optim.Adam(model.parameters(), lr=1e-5)
total_steps = len(loader) * num_epochs
criterion = torch.nn.CrossEntropyLoss()
# The model already applies a sigmoid to the has-answer head, hence BCELoss on probabilities.
criterion_has_answer = torch.nn.BCELoss()
# Linear decay with a short warm-up (0.1% of all steps); the step count is passed as an integer.
sheduler = get_linear_schedule_with_warmup(optimizer,
                                           num_warmup_steps=int(total_steps * 0.001),
                                           num_training_steps=total_steps)

for epoch in range(num_epochs):
    # from_pretrained() hands the encoder back in evaluation mode, so training mode
    # (dropout enabled) has to be switched on explicitly.
    model.train()
    epoch_loss = 0.0

    with tqdm(loader, unit='batch') as tepoch:
        tepoch.set_description(f"Epoch {epoch+1}/{num_epochs}")

        for batch in tepoch:

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # The collated labels are already tensors: move them instead of
            # re-wrapping them in torch.tensor(), which copies and warns.
            labels_start = batch['labels'][0].to(device)
            labels_end = batch['labels'][1].to(device)
            has_answer = batch['has_answer'].to(device)

            # Feed the tensors that were moved to `device`, not the CPU originals.
            start_logits, end_logits, has_answer_prob = model(input_ids=input_ids,
                                                              attention_mask=attention_mask)

            loss_classification = criterion_has_answer(has_answer_prob, has_answer.float())

            # The start / end losses are computed only over examples that have an answer
            if has_answer.sum() > 0:
                mask = has_answer.bool()
                start_loss = criterion(start_logits[mask], labels_start[mask])
                end_loss = criterion(end_logits[mask], labels_end[mask])
                loss = loss_classification + start_loss + end_loss
            else:
                loss = loss_classification

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sheduler.step()

            epoch_loss += loss.item()
            tepoch.set_postfix(loss=loss.item())

    # Save the model, optimizer and scheduler state every `checkpoint_every` epochs
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': sheduler.state_dict(),
            'loss': epoch_loss / len(loader)
        }, checkpoint_path)
        print(f"Model saved to {checkpoint_path}")

    print(f"Epoch {epoch+1} finished with loss: {epoch_loss / len(loader)}")