In [1]:
import os
import nltk
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, T5Tokenizer
import numpy as np
import torch
import pandas as pd

from bert_score import score as bert_score_compute
from rouge import Rouge
from nltk.translate.meteor_score import meteor_score
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.chrf_score import corpus_chrf

# Download the NLTK data packages this notebook relies on (quietly, no progress log).
# Presumably 'punkt'/'punkt_tab' back nltk.word_tokenize and 'wordnet' backs
# meteor_score — both are called in compute_metrics; verify against the NLTK version in use.
nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)
nltk.download('wordnet', quiet=True)

  from .autonotebook import tqdm as notebook_tqdm


True

# Загрузка и подготовка датасета

In [2]:
# CSV file that is loaded as the "validation" split for evaluation.
# NOTE(review): the constant is called VALID_FILE but points at test.csv —
# confirm which split is actually meant to be scored.
VALID_FILE = '/content/drive/MyDrive/final_data/test.csv'
# Hub identifier used to load the tokenizer.
MODEL_NAME = "ai-forever/ruT5-base"
# Directory that is scanned for "checkpoint-*" sub-folders to evaluate.
OUTPUT_DIR = "/content/drive/MyDrive/checkpoints/ruT5_sum"

# Truncation limit (in tokens) applied to the source texts.
MAX_SOURCE_LENGTH = 512
# Truncation limit (in tokens) for reference summaries; also used as generation_max_length.
MAX_TARGET_LENGTH = 64
# Passed to the trainer as per_device_eval_batch_size.
BATCH_SIZE_PER_DEVICE = 2

In [4]:
# Shared ROUGE scorer; compute_metrics reads this module-level name directly.
rouge = Rouge()

# Tokenizer of the base model; reused for preprocessing, batch collation and decoding.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

In [10]:
def preprocess_function(examples):
    """Tokenize a batch of source texts and their reference summaries.

    Args:
        examples: Batch mapping with "text" and "summary" entries, each an
            iterable of values. ``None`` entries become empty strings and
            every other value is converted with ``str``.

    Returns:
        The tokenizer output for the source texts (truncated to
        ``MAX_SOURCE_LENGTH``) with an extra "labels" entry holding the
        token ids of the summaries (truncated to ``MAX_TARGET_LENGTH``).
    """
    def _as_text(value):
        # Missing cells arrive as None; map them to "" instead of the string "None".
        return "" if value is None else str(value)

    source_texts = list(map(_as_text, examples["text"]))
    target_texts = list(map(_as_text, examples["summary"]))

    encoded = tokenizer(source_texts, max_length=MAX_SOURCE_LENGTH, truncation=True)
    target_encoding = tokenizer(
        text_target=target_texts,
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded

In [12]:
# Load the evaluation CSV through the generic "csv" dataset builder, exposing it
# under a split named "validation".
data_files_eval = {"validation": VALID_FILE}
raw_eval_datasets = load_dataset("csv", data_files=data_files_eval)
eval_dataset = raw_eval_datasets["validation"]
# Report the number of examples (the message text is user-facing and kept as is).
print(f"Количество примеров: {len(eval_dataset)}")

Generating validation split: 0 examples [00:00, ? examples/s]

Количество примеров: 3228


In [13]:
# Tokenize the whole evaluation set; batched=True hands preprocess_function
# lists of values per column rather than single rows.
tokenized_eval_dataset = eval_dataset.map(preprocess_function,batched=True)

Map:   0%|          | 0/3228 [00:00<?, ? examples/s]

# Подсчёт метрик

In [None]:
def _mean_rouge_f(hypotheses, references):
    """Return mean ROUGE-1/2/L F-scores (range 0..1) over aligned text pairs.

    Pairs are scored one at a time so that a single unscorable pair cannot
    abort the whole evaluation: the ``rouge`` package raises ``ValueError``
    for a hypothesis or reference it considers empty. Such pairs contribute
    a score of 0 and still count in the denominator.
    """
    totals = {"rouge-1": 0.0, "rouge-2": 0.0, "rouge-l": 0.0}
    for hyp, ref in zip(hypotheses, references):
        if not hyp:
            continue  # empty prediction -> zero overlap by definition
        try:
            pair_scores = rouge.get_scores([hyp], [ref])[0]
        except ValueError:
            # Raised for texts the scorer treats as empty (e.g. only periods).
            continue
        for name in totals:
            totals[name] += pair_scores.get(name, {}).get("f", 0.0)

    n_pairs = len(hypotheses)
    return {
        name: (total / n_pairs if n_pairs else 0.0)
        for name, total in totals.items()
    }


def compute_metrics(eval_preds, current_tokenizer):
    """Decode generated ids and score them against the reference summaries.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair. ``predictions`` may be
            a tuple, in which case its first element is used. Both are arrays
            of token ids; ``-100`` in the labels marks padding.
        current_tokenizer: Tokenizer providing ``pad_token_id``,
            ``vocab_size`` and ``batch_decode``.

    Returns:
        dict with the keys ``gen_len``, ``rouge1_f``, ``rouge2_f``,
        ``rougel_f``, ``bert_score_f1``, ``chrf++``, ``bleu`` and ``meteor``.
        All quality scores are percentages rounded to 4 decimals.

    Notes:
        * Every metric is computed over the same set of samples: all pairs
          whose reference is non-empty. An empty prediction stays in the set
          and scores zero, instead of crashing ROUGE or being silently
          dropped from BLEU/METEOR.
        * If no pair has a non-empty reference, all quality scores are 0.0.
        * NOTE(review): the ``chrf++`` key is filled from NLTK's
          ``corpus_chrf`` with ``beta=1.0``; confirm this is the intended
          chrF variant. The key is kept for compatibility with saved results.
    """
    # Unpack predictions and labels.
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = current_tokenizer.pad_token_id

    # Replace NaN / +-inf with the pad token id.
    preds_sanitized = np.nan_to_num(preds, nan=pad_id, posinf=pad_id, neginf=pad_id)
    vocab_size = current_tokenizer.vocab_size
    preds_sanitized_int = preds_sanitized.astype(np.int64)

    # Replace ids outside the vocabulary (including negative padding such as -100)
    # with the pad token id so they can be decoded.
    preds_final_sanitized = np.where(
        (preds_sanitized_int >= 0) & (preds_sanitized_int < vocab_size),
        preds_sanitized_int,
        pad_id,
    ).astype(np.int32)

    # Decode predictions to text, skipping special tokens.
    decoded_preds_raw = current_tokenizer.batch_decode(preds_final_sanitized, skip_special_tokens=True)

    # Labels use -100 for padding; swap it for pad_token_id so they can be decoded.
    labels_processed = np.where(labels != -100, labels, pad_id)
    decoded_labels_raw = current_tokenizer.batch_decode(labels_processed, skip_special_tokens=True)

    # Build ONE aligned sample set shared by all metrics. Samples without a
    # reference cannot be scored and are skipped; empty predictions are kept.
    pred_texts, ref_texts = [], []
    pred_tokens, ref_tokens = [], []
    for pred_str, label_str in zip(decoded_preds_raw, decoded_labels_raw):
        ref_text = label_str.strip()
        if not ref_text:
            continue
        pred_text = pred_str.strip()
        pred_texts.append(pred_text)
        ref_texts.append(ref_text)
        pred_tokens.append(
            nltk.word_tokenize(pred_text, language='russian') if pred_text else []
        )
        # corpus_bleu / meteor_score expect a list of references per hypothesis.
        ref_tokens.append([nltk.word_tokenize(ref_text, language='russian')])

    metrics = {}

    # Mean number of non-pad tokens in the generated sequences.
    prediction_lens = [np.count_nonzero(p != pad_id) for p in preds_final_sanitized]
    metrics["gen_len"] = round(float(np.mean(prediction_lens)) if prediction_lens else 0.0, 4)

    # Nothing to score: return zeros rather than letting the scorers fail on empty input.
    if not pred_texts:
        for key in ("rouge1_f", "rouge2_f", "rougel_f", "bert_score_f1", "chrf++", "bleu", "meteor"):
            metrics[key] = 0.0
        return metrics

    # ROUGE
    rouge_f = _mean_rouge_f(pred_texts, ref_texts)
    metrics["rouge1_f"] = round(rouge_f["rouge-1"] * 100, 4)
    metrics["rouge2_f"] = round(rouge_f["rouge-2"] * 100, 4)
    metrics["rougel_f"] = round(rouge_f["rouge-l"] * 100, 4)

    # BERTScore
    bert_device = "cuda" if torch.cuda.is_available() else "cpu"
    _, _, F1_bert = bert_score_compute(
        cands=pred_texts, refs=ref_texts, lang="ru",
        verbose=False, device=bert_device, batch_size=16)
    f1_bert_mean = F1_bert.mean().item()
    metrics["bert_score_f1"] = round(f1_bert_mean * 100, 4) if not np.isnan(f1_bert_mean) else 0.0

    # chrF (stored under the historical "chrf++" key)
    chrf_score_val = corpus_chrf(
        references=ref_texts,
        hypotheses=pred_texts,
        beta=1.0
    )
    metrics["chrf++"] = round(float(chrf_score_val) * 100, 4)

    # BLEU
    bleu_score_val = corpus_bleu(
        list_of_references=ref_tokens,
        hypotheses=pred_tokens
    )
    metrics["bleu"] = round(float(bleu_score_val) * 100, 4)

    # METEOR (an empty hypothesis scores 0 without calling the scorer)
    meteor_scores_list = [
        meteor_score(refs, hyp) if hyp else 0.0
        for hyp, refs in zip(pred_tokens, ref_tokens)
    ]
    metrics["meteor"] = round(float(np.mean(meteor_scores_list)) * 100, 4)

    return metrics

In [None]:
# Collect the "checkpoint-<step>" directories under OUTPUT_DIR, ordered by step.
# Only names with a purely numeric suffix are accepted, so unrelated folders
# such as "checkpoint-best" are skipped instead of crashing int() in the sort key.
_CKPT_PREFIX = "checkpoint-"
checkpoint_folders = sorted(
    (
        os.path.join(OUTPUT_DIR, d)
        for d in os.listdir(OUTPUT_DIR)
        if d.startswith(_CKPT_PREFIX)
        and d[len(_CKPT_PREFIX):].isdecimal()
        and os.path.isdir(os.path.join(OUTPUT_DIR, d))
    ),
    # Parse the step from the folder name only, not from the full path.
    key=lambda path: int(os.path.basename(path)[len(_CKPT_PREFIX):]),
)

checkpoint_folders

['/content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-1',
 '/content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-5',
 '/content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-10',
 '/content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-15',
 '/content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-20']

In [None]:
# Evaluate every checkpoint on the tokenized evaluation set and collect one
# metrics row per checkpoint in all_results.
all_results = []
device = "cuda" if torch.cuda.is_available() else "cpu"

# The collator and the evaluation arguments do not depend on the checkpoint,
# so they are built once instead of on every iteration.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    label_pad_token_id=tokenizer.pad_token_id,
    pad_to_multiple_of=8
)

training_args = Seq2SeqTrainingArguments(
    output_dir="./temp_eval_output",
    per_device_eval_batch_size=BATCH_SIZE_PER_DEVICE,
    predict_with_generate=True,  # predictions are generated token ids, not logits
    generation_max_length=MAX_TARGET_LENGTH,
    generation_num_beams=4,
    fp16=torch.cuda.is_available(),
    report_to="none"
)

model = trainer = None
for ckpt_path in checkpoint_folders:
    print(f"Оценка чекпоинта: {ckpt_path}")

    # Drop the previous checkpoint before loading the next one so that two
    # models are never resident on the GPU at the same time.
    model = trainer = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_path).to(device)
    model.eval()

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    predictions_output = trainer.predict(
        test_dataset=tokenized_eval_dataset,
        metric_key_prefix="eval"
    )

    metrics = compute_metrics((predictions_output.predictions, predictions_output.label_ids), tokenizer)

    print(f"Метрики для {os.path.basename(ckpt_path)}:")
    for k, v in metrics.items():
        print(f"  {k}: {v}")

    result_entry = {"checkpoint": os.path.basename(ckpt_path)}
    result_entry.update(metrics)
    all_results.append(result_entry)

Оценка чекпоинта: /content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-1


  trainer = Seq2SeqTrainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

Метрики для checkpoint-1:
  gen_len: 28.5254
  rouge1_f: 27.5093
  rouge2_f: 12.8689
  rougel_f: 24.9481
  bert_score_f1: 76.9721
  chrf++: 36.8263
  bleu: 8.7239
  meteor: 25.7877
Оценка чекпоинта: /content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-5


  trainer = Seq2SeqTrainer(


Метрики для checkpoint-5:
  gen_len: 29.5675
  rouge1_f: 29.6157
  rouge2_f: 14.3995
  rougel_f: 26.8771
  bert_score_f1: 77.9207
  chrf++: 39.0962
  bleu: 10.2914
  meteor: 28.3542
Оценка чекпоинта: /content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-10


  trainer = Seq2SeqTrainer(


Метрики для checkpoint-10:
  gen_len: 29.5421
  rouge1_f: 30.4751
  rouge2_f: 15.0693
  rougel_f: 27.705
  bert_score_f1: 78.2533
  chrf++: 39.8055
  bleu: 10.7023
  meteor: 29.0795
Оценка чекпоинта: /content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-15


  trainer = Seq2SeqTrainer(


Метрики для checkpoint-15:
  gen_len: 29.6419
  rouge1_f: 30.607
  rouge2_f: 15.1734
  rougel_f: 27.8276
  bert_score_f1: 78.3665
  chrf++: 39.9895
  bleu: 10.7479
  meteor: 29.2643
Оценка чекпоинта: /content/drive/MyDrive/checkpoints/ruT5_sum/checkpoint-20


  trainer = Seq2SeqTrainer(


Метрики для checkpoint-20:
  gen_len: 29.7094
  rouge1_f: 30.7346
  rouge2_f: 15.2225
  rougel_f: 27.9426
  bert_score_f1: 78.3572
  chrf++: 40.0573
  bleu: 10.9116
  meteor: 29.4225


In [18]:
# One row per evaluated checkpoint: the checkpoint name plus its metric values.
results_df = pd.DataFrame(all_results)
results_df

Unnamed: 0,checkpoint,gen_len,rouge1_f,rouge2_f,rougel_f,bert_score_f1,chrf++,bleu,meteor
0,checkpoint-1,28.5254,27.5093,12.8689,24.9481,76.9721,36.8263,8.7239,25.7877
1,checkpoint-5,29.5675,29.6157,14.3995,26.8771,77.9207,39.0962,10.2914,28.3542
2,checkpoint-10,29.5421,30.4751,15.0693,27.705,78.2533,39.8055,10.7023,29.0795
3,checkpoint-15,29.6419,30.607,15.1734,27.8276,78.3665,39.9895,10.7479,29.2643
4,checkpoint-20,29.7094,30.7346,15.2225,27.9426,78.3572,40.0573,10.9116,29.4225


In [None]:
# Save the metrics table for later analysis (row index is not written).
results_df.to_csv("metrics_ruT5.csv", index=False) 