In [None]:
!pip install datasets rouge

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    AutoModelForQuestionAnswering,
    AutoModelForTokenClassification,
)

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from datasets import load_dataset
from rouge import Rouge
from nltk.translate.bleu_score import corpus_bleu

import pandas as pd
import yaml
from tqdm import tqdm

In [None]:
# Образец yaml-конфигурации

"""
device: cpu
model_name: cointegrated/rubert-tiny2
model_type: classification
num_labels: 5
max_length: 512
stride: 128
batch_size: 8
learning_rate: 2e-5
num_epochs: 3
scheduler_step_size: 1
grad_accumulation_steps: 4
eval_every: 500
early_stopping_patience: 3
save_best_model: true

tasks:
  sentiment_analysis:
    dataset_name: MonoHime/ru_sentiment_dataset
    subset_name: default
    text_column: text
    label_column: sentiment
    metrics: [accuracy, f1_macro]
  classification:
    dataset_name: ag_news
    subset_name: default
    text_column: text
    label_column: label
    metrics: [accuracy, f1_macro]
  ner:
    dataset_name: iluvvatar/NEREL
    subset_name: data
    text_column: text
    label_column: entities
    metrics: [precision, recall, f1]
  summarization:
    dataset_name: zjkarina/matreshka
    subset_name: default
    input_columns: [role, dialog, persona]
    summary_column: summary
    metrics: [rouge, bleu]
  question_answering:
    dataset_name: RussianNLP/russian_super_glue
    subset_name: rucos
    input_columns: [passage]
    question_column: question
    metrics: [exact_match, f1]
"""

In [None]:
# Dataset definition
class MultiTaskDataset(Dataset):
    """Expose several per-task datasets as one flat, index-addressable dataset.

    Global indices are laid out task by task, in the iteration order of
    ``datasets``: indices ``0 .. len(first) - 1`` address the first task's
    samples, the next ``len(second)`` indices address the second task's, etc.

    Args:
        datasets: mapping ``task name -> data``. ``data`` is either directly
            indexable by integer, or a dict-like container (e.g. a split
            dictionary) from which ``task_to_datakey`` selects the part to use.
        task_to_datakey: mapping ``task name -> key`` selecting which part of
            ``datasets[task]`` to draw samples from (e.g. ``'train'``). Tasks
            missing from the mapping, or mapped to ``None``, use
            ``datasets[task]`` as-is. ``None`` is accepted as an empty mapping.
    """

    def __init__(self, datasets, task_to_datakey):
        self.datasets = datasets
        self.task_to_datakey = task_to_datakey if task_to_datakey is not None else {}

    def _source(self, task):
        """Return the integer-indexable collection holding ``task``'s samples."""
        data = self.datasets[task]
        datakey = self.task_to_datakey.get(task)
        return data if datakey is None else data[datakey]

    def __len__(self):
        return sum(len(self._source(task)) for task in self.datasets)

    def __getitem__(self, idx):
        """Return the sample at global position ``idx`` (negative allowed).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if idx < 0 or idx >= total:
            raise IndexError(f"index out of range for MultiTaskDataset of length {total}")

        # Walk the tasks, subtracting each task's size until idx falls inside one.
        for task in self.datasets:
            source = self._source(task)
            if idx < len(source):
                return source[idx]
            idx -= len(source)

        # Unreachable: the bounds check above guarantees a hit.
        raise IndexError("index out of range")

In [None]:
# Load the configuration. An explicit encoding keeps parsing independent of
# the platform's default locale encoding.
with open('config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Load the datasets: one entry per configured task. No `split` is passed to
# load_dataset, so each entry holds every split the dataset provides.
# `subset_name` is optional (None selects the dataset's default config).
datasets = {}
for task, task_config in config['tasks'].items():
    dataset = load_dataset(task_config['dataset_name'],
                           task_config.get('subset_name'))
    datasets[task] = dataset

# Tokenization: a single tokenizer shared by all tasks
tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

def preprocess_function(examples, task_config, task=None):
    """Tokenize one batch of raw examples for a given task.

    Designed for ``Dataset.map(..., batched=True)``: ``examples`` maps each
    column name to a list of values.

    Args:
        examples: batch of raw examples (column name -> list of values).
        task_config: this task's section of the YAML config.
        task: task name (a key under ``tasks`` in the config). When omitted,
            the module-level ``task`` variable is used, which the notebook's
            tokenization loop binds to the task currently being processed.

    Returns:
        The tokenizer output (``BatchEncoding``) for the batch.

    Raises:
        ValueError: if ``task`` is not one of the supported task names.
    """
    if task is None:
        # Backward compatibility with the original implicit-global contract.
        task = globals().get('task')

    def setting(key, default=None):
        # Per-task value, falling back to the top-level config (the sample
        # config defines max_length/stride only at the top level).
        value = task_config.get(key, config.get(key))
        return default if value is None else value

    def joined_inputs():
        # Join several input columns into one string per example.
        columns = task_config['input_columns']
        if isinstance(columns, str):
            # A YAML scalar (e.g. `input_columns: passage`) would otherwise be
            # iterated character by character.
            columns = [columns]
        return [' '.join(examples[col][i] for col in columns)
                for i in range(len(examples[columns[0]]))]

    if task in ('classification', 'sentiment_analysis'):
        # Pad to a fixed length so batches can be stacked by the DataLoader.
        tokenized = tokenizer(examples[task_config['text_column']],
                              truncation=True,
                              max_length=setting('max_length'),
                              padding='max_length')
        label_column = task_config.get('label_column')
        if label_column is not None and label_column in examples:
            # The training and validation loops read batch['labels'].
            tokenized['labels'] = examples[label_column]
        return tokenized

    if task == 'summarization':
        # text_target tokenizes the summaries as target labels (stored under
        # 'labels') rather than appending them to the encoder input as a pair.
        return tokenizer(joined_inputs(),
                         text_target=examples[task_config['summary_column']],
                         truncation=True)

    if task == 'ner':
        # NOTE(review): assumes text_column holds pre-split words (required by
        # is_split_into_words=True) and label_column one integer label per
        # word — TODO confirm against the dataset schema.
        tokenized_examples = tokenizer(
            examples[task_config['text_column']],
            truncation=True,
            is_split_into_words=True,
        )
        labels = []
        for i, word_labels in enumerate(examples[task_config['label_column']]):
            word_ids = tokenized_examples.word_ids(batch_index=i)
            # Tokens with no source word (word_id None) get -100, the default
            # ignore_index of PyTorch's CrossEntropyLoss.
            labels.append([-100 if word_id is None else word_labels[word_id]
                           for word_id in word_ids])
        tokenized_examples['labels'] = labels
        return tokenized_examples

    if task == 'question_answering':
        # NOTE(review): return_overflowing_tokens may produce more rows than
        # input examples; Dataset.map then requires the original columns to be
        # removed — TODO confirm with the caller.
        return tokenizer(
            examples[task_config['question_column']],
            joined_inputs(),
            truncation='only_second',
            max_length=setting('max_length'),
            stride=setting('stride', 0),
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length',
        )

    raise ValueError(f"Unsupported task: {task!r}")

# Tokenize every split of every task's dataset with the shared tokenizer.
# preprocess_function picks up the current task name from the module-level
# `task` variable bound by this loop; `map` runs eagerly, so the lambda sees
# the current task_config.
tokenized_datasets = {}
for task, dataset in datasets.items():
    task_config = config['tasks'][task]
    tokenized_datasets[task] = dataset.map(
        lambda examples: preprocess_function(examples, task_config),
        batched=True,
        # remove_columns=dataset.column_names
    )

# Build the training DataLoader: every task contributes its 'train' split.
task_to_datakey = {task: 'train' for task in tokenized_datasets}
train_dataset = MultiTaskDataset(tokenized_datasets,
                                 task_to_datakey)
train_dataloader = DataLoader(train_dataset,
                              batch_size=config['batch_size'],
                              shuffle=True)

# Load the pretrained model matching config['model_type'].
model_type = config['model_type']
if model_type in ('classification', 'sentiment_analysis'):
    model = AutoModelForSequenceClassification.from_pretrained(
        config['model_name'], num_labels=config['num_labels']
    )
elif model_type == 'summarization':
    model = AutoModelForSeq2SeqLM.from_pretrained(config['model_name'])
elif model_type == 'question_answering':
    model = AutoModelForQuestionAnswering.from_pretrained(config['model_name'])
elif model_type == 'ner':
    model = AutoModelForTokenClassification.from_pretrained(
        config['model_name'], num_labels=config['num_labels']
    )
else:
    # Raise instead of exit(): exit() is not guaranteed to exist outside the
    # interactive `site` setup and would kill a notebook kernel.
    raise ValueError("Use only params in 'model_type': classification/summarization/question_answering/ner/sentiment_analysis")

# Optimizer and scheduler.
# PyYAML parses `2e-5` (no decimal point) as a string, so coerce to float.
optimizer = optim.AdamW(model.parameters(),
                        lr=float(config['learning_rate']))
# StepLR multiplies the LR by `gamma` every `step_size` scheduler.step() calls.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=config['scheduler_step_size'],
    gamma=float(config.get('scheduler_gamma', 0.1)),  # 0.1 is StepLR's default
)

# ROUGE scorer used during summarization validation
rouge = Rouge()

In [None]:
# Training loop.
# NOTE: grad_accumulation_steps, eval_every and early_stopping_patience from
# the sample config are not read here yet.
device = config['device']
# Inputs are moved to `device` below, so the weights must live there too.
model.to(device)
results = []

for epoch in range(config['num_epochs']):
    # Training mode (validation below switches to eval mode).
    model.train()
    epoch_losses = []

    train_pbar = tqdm(
        train_dataloader,
        desc=f"Epoch {epoch + 1} [Training]",
        position=0,
        leave=True
    )

    # Iterating the progress bar advances it; no manual update() needed.
    for batch in train_pbar:
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())
        train_pbar.set_postfix({'loss': epoch_losses[-1]})

    # StepLR counts scheduler.step() calls; stepping once per epoch keeps
    # scheduler_step_size measured in epochs instead of batches.
    scheduler.step()

    # NaN when the loader yielded no batches, instead of a NameError.
    last_loss = epoch_losses[-1] if epoch_losses else float('nan')
    mean_loss = (sum(epoch_losses) / len(epoch_losses)
                 if epoch_losses else float('nan'))

    # Validation
    model.eval()
    val_scores = {}
    for task, dataset in tokenized_datasets.items():
        if 'validation' not in dataset:
            continue
        val_loader = DataLoader(dataset['validation'],
                                batch_size=config['batch_size'])

        if task in ('classification', 'sentiment_analysis'):
            val_preds = []
            val_labels = []
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                with torch.no_grad():
                    outputs = model(input_ids, attention_mask=attention_mask)

                preds = torch.argmax(outputs.logits, dim=-1)
                val_preds.extend(preds.tolist())
                val_labels.extend(batch['labels'].tolist())

            val_scores[f'{task}_accuracy'] = accuracy_score(val_labels, val_preds)
            val_scores[f'{task}_f1'] = f1_score(val_labels, val_preds, average='macro')

        elif task == 'summarization':
            val_preds = []
            val_labels = []
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                with torch.no_grad():
                    outputs = model.generate(
                        input_ids,
                        attention_mask=attention_mask
                    )

                val_preds.extend(
                    tokenizer.batch_decode(outputs, skip_special_tokens=True)
                )
                val_labels.extend(
                    tokenizer.batch_decode(batch['labels'], skip_special_tokens=True)
                )

            # ROUGE
            rouge_scores = rouge.get_scores(val_preds, val_labels, avg=True)
            val_scores[f'{task}_rouge_1'] = rouge_scores['rouge-1']['f']
            val_scores[f'{task}_rouge_2'] = rouge_scores['rouge-2']['f']
            val_scores[f'{task}_rouge_l'] = rouge_scores['rouge-l']['f']

            # BLEU: corpus_bleu expects token lists; raw strings would be
            # scored character by character.
            val_scores[f'{task}_bleu'] = corpus_bleu(
                [[ref.split()] for ref in val_labels],
                [pred.split() for pred in val_preds],
            )

        elif task == 'ner':
            val_preds = []
            val_labels = []
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                with torch.no_grad():
                    outputs = model(input_ids, attention_mask=attention_mask)

                preds = torch.argmax(outputs.logits, dim=-1)
                # Score only real tokens that carry a gold label: the NER
                # preprocessing marks tokens without a word with -100, which
                # is not a key of id2label.
                keep = attention_mask.bool() & (labels != -100)
                val_preds.extend(preds[keep].tolist())
                val_labels.extend(labels[keep].tolist())

            id2label = model.config.id2label
            val_preds = [id2label[pred] for pred in val_preds]
            val_labels = [id2label[label] for label in val_labels]

            val_scores[f'{task}_precision'] = precision_score(val_labels, val_preds, average='micro')
            val_scores[f'{task}_recall'] = recall_score(val_labels, val_preds, average='micro')
            val_scores[f'{task}_f1'] = f1_score(val_labels, val_preds, average='micro')

        elif task == 'question_answering':
            # TODO: exact_match / f1 for QA are not implemented. The
            # preprocessing produces no start/end positions, so the previous
            # code failed with KeyError while scoring nothing.
            print(f"Skipping validation for '{task}': evaluation not implemented")

    # Save results once per epoch (not once per task).
    results.append({'epoch': epoch, 'loss': last_loss, 'mean_loss': mean_loss, **val_scores})
    print(
        f"Epoch {epoch+1}/{config['num_epochs']}, Loss: {last_loss}, Validation Scores: {val_scores}"
    )

    # Save results to CSV (rewritten every epoch)
    results_df = pd.DataFrame(results)
    results_df.to_csv('results.csv', index=False)

    # Save the model weights. This is the latest epoch's model; no "best"
    # selection is performed. '/' in hub model names (e.g. "org/model") would
    # otherwise point into a non-existent directory.
    if config.get('save_best_model', True):
        safe_model_name = config['model_name'].replace('/', '_')
        best_model_path = f"best_model_{safe_model_name}.pth"
        torch.save(model.state_dict(), best_model_path)