# RAG with ruT5

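Retrieval-augmented generation (RAG): BM25 retrieves similar annotated fragments, which are given as context to a fine-tuned ruT5 model that generates the annotation. A GPU is recommended.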


## Setup


In [None]:
from pathlib import Path
import json

_data_dir = Path('..') / 'data'
_full_path = _data_dir / 'annotations_dataset_full.json'
# Use the full annotation dump when it exists; otherwise fall back to
# annotations_dataset_new.json (its existence is not checked here).
DATASET_PATH = (
    _full_path if _full_path.exists()
    else _data_dir / 'annotations_dataset_new.json'
)
print('Dataset:', DATASET_PATH)


## Load pairs


In [None]:
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Flatten songs -> one (fragment, annotation, metadata) triple per annotation.
# `or` (rather than only a .get default) also replaces explicit JSON nulls,
# which would otherwise break tokenization and the votes comparison later.
fragments = []
annotations = []
metadata = []
for song in data:
    for ann in song.get('annotations') or []:
        fragments.append(ann.get('fragment') or '')
        annotations.append(ann.get('annotation') or '')
        metadata.append({
            'artist': song.get('artist') or '',
            'title': song.get('title') or '',
            'votes': ann.get('votes') or 0,
        })

print('Pairs:', len(fragments))

MAX_EXAMPLES = 2000  # set None for full run
if MAX_EXAMPLES:
    import random
    # A local RNG seeded with 42 gives the same permutation as
    # random.seed(42) + random.shuffle, without resetting the global RNG.
    idx = list(range(len(fragments)))
    random.Random(42).shuffle(idx)
    idx = idx[:MAX_EXAMPLES]
    fragments = [fragments[i] for i in idx]
    annotations = [annotations[i] for i in idx]
    metadata = [metadata[i] for i in idx]
    print('Using subset:', len(fragments))


## BM25 retriever


In [None]:
import math
import re
from collections import Counter
import numpy as np

_TOKEN_RE = re.compile(r'[A-Za-z0-9\u0400-\u04FF]+')


def tokenize(text):
    """Lowercase ``text`` and split it into Latin/Cyrillic alphanumeric tokens.

    Every character outside A-Z, a-z, 0-9 and the Cyrillic block
    U+0400-U+04FF acts as a separator. Empty or ``None`` input yields [].
    """
    if not text:
        return []
    return _TOKEN_RE.findall(text.lower())


class BM25Retriever:
    """Okapi BM25 retriever over a list of fragment strings.

    Each document is a fragment; hits also carry the annotation at the same
    index. The idf used is ``log((N - df + 0.5) / (df + 0.5) + 1)``, which is
    always positive.

    Args:
        fragments: Document texts that are tokenized and scored.
        annotations: Texts aligned with ``fragments``, returned with each hit.
        metadata: Per-document records aligned with ``fragments``; stored on
            the instance but not used for scoring.
        k1: Term-frequency saturation parameter.
        b: Document-length normalization strength (0 = none, 1 = full).
    """

    def __init__(self, fragments, annotations, metadata, k1=1.5, b=0.75):
        self.fragments = fragments
        self.annotations = annotations
        self.metadata = metadata
        self.k1 = k1
        self.b = b
        self.doc_tokens = [tokenize(t) for t in self.fragments]
        self.doc_lens = np.array([len(t) for t in self.doc_tokens], dtype=np.float32)
        self.avgdl = float(np.mean(self.doc_lens)) if self.doc_lens.size else 0.0
        self.N = len(self.doc_tokens)
        self.term_freqs = [Counter(t) for t in self.doc_tokens]

        # Inverted index: term -> (doc ids, term frequencies). Scoring then
        # only touches documents that contain a query term instead of
        # looping over the whole corpus for every query.
        index = {}
        for doc_id, tf in enumerate(self.term_freqs):
            for term, freq in tf.items():
                ids, freqs = index.setdefault(term, ([], []))
                ids.append(doc_id)
                freqs.append(freq)
        self.postings = {
            term: (np.asarray(ids, dtype=np.int64), np.asarray(freqs, dtype=np.float64))
            for term, (ids, freqs) in index.items()
        }
        self.doc_freqs = {term: len(ids) for term, (ids, _) in self.postings.items()}
        self.idf = {
            term: math.log((self.N - df + 0.5) / (df + 0.5) + 1.0)
            for term, df in self.doc_freqs.items()
        }
        # Per-document length normalization: k1 * (1 - b + b * dl / avgdl).
        if self.avgdl > 0:
            self._norm = self.k1 * (
                1.0 - self.b + self.b * self.doc_lens.astype(np.float64) / self.avgdl
            )
        else:
            self._norm = np.zeros(self.N, dtype=np.float64)

    def _score(self, query_tokens):
        """Return a float32 array with one BM25 score per document.

        A query term that occurs several times contributes once per occurrence.
        """
        scores = np.zeros(self.N, dtype=np.float64)
        for term in query_tokens:
            posting = self.postings.get(term)
            if posting is None:
                continue
            ids, freqs = posting
            scores[ids] += (
                self.idf[term] * freqs * (self.k1 + 1.0) / (freqs + self._norm[ids])
            )
        return scores.astype(np.float32)

    def top_k(self, query, k=3):
        """Return up to ``k`` best-scoring documents for ``query``, best first.

        Ties (including zero-score documents) are broken by corpus order,
        lower index first. ``k <= 0`` or an empty corpus returns ``[]``.

        Returns:
            A list of dicts with keys ``fragment``, ``annotation`` and ``score``.
        """
        if k <= 0 or self.N == 0:
            return []
        scores = self._score(tokenize(query))
        # Stable sort on negated scores: descending score, ascending index.
        order = np.argsort(-scores, kind='stable')[:k]
        return [
            {
                'fragment': self.fragments[i],
                'annotation': self.annotations[i],
                'score': float(scores[i]),
            }
            for i in order
        ]


## Build RAG pairs


In [None]:
from sklearn.model_selection import train_test_split

MIN_VOTES = 1  # set 0 to keep all
VAL_SIZE = 0.1
CONTEXT_TOP_K = 3
# Extra BM25 candidates fetched when exact self-matches are dropped, so the
# context usually still contains CONTEXT_TOP_K hits.
_EXCLUDE_OVERFETCH = 5

# Examples with enough votes become supervised (prompt, target) pairs.
kept_idx = [i for i, meta in enumerate(metadata) if meta.get('votes', 0) >= MIN_VOTES]

# Split BEFORE building the retriever. If validation examples were indexed,
# each validation prompt would retrieve its own gold annotation as the top
# hit, and the reported metrics would measure copying, not generation.
# (Same length + random_state as splitting the pair list itself, so the
# train/val membership is unchanged.)
train_idx, val_idx = train_test_split(kept_idx, test_size=VAL_SIZE, random_state=42)
_val_set = set(val_idx)
corpus_idx = [i for i in range(len(fragments)) if i not in _val_set]
retriever = BM25Retriever(
    [fragments[i] for i in corpus_idx],
    [annotations[i] for i in corpus_idx],
    [metadata[i] for i in corpus_idx],
)


def build_context(query, top_k=3, exclude_exact=False):
    """Format the top BM25 hits for ``query`` as a numbered context block.

    Args:
        query: Fragment text used as the BM25 query.
        top_k: Maximum number of hits to include; ``<= 0`` gives ''.
        exclude_exact: If True, drop hits whose fragment text equals
            ``query`` (ignoring surrounding whitespace). Used for supervised
            examples: a training fragment is in the corpus, so otherwise its
            own target annotation would be placed in its context.

    Returns:
        One line per hit, ``[i] fragment :: annotation``, newline-joined.
    """
    if top_k <= 0:
        return ''
    if exclude_exact:
        wanted = query.strip()
        candidates = retriever.top_k(query, k=top_k + _EXCLUDE_OVERFETCH)
        hits = [h for h in candidates if h['fragment'].strip() != wanted][:top_k]
    else:
        hits = retriever.top_k(query, k=top_k)
    return '\n'.join(
        f"[{i}] {h['fragment']} :: {h['annotation']}" for i, h in enumerate(hits, 1)
    )


def _make_pair(i):
    """Build the supervised prompt/target dict for example index ``i``."""
    frag = fragments[i]
    ctx = build_context(frag, top_k=CONTEXT_TOP_K, exclude_exact=True)
    source = (
        'Задача: объясни строку с учетом контекста.\n'
        f'CONTEXT:\n{ctx}\n\n'
        f'FRAGMENT:\n{frag}\n\n'
        'ANNOTATION:'
    )
    return {'input': source, 'target': annotations[i]}


train_pairs = [_make_pair(i) for i in train_idx]
val_pairs = [_make_pair(i) for i in val_idx]
pairs = train_pairs + val_pairs
print('Train:', len(train_pairs), 'Val:', len(val_pairs))


## Build HF datasets


In [None]:
from datasets import Dataset

# Wrap the prompt/target dicts (columns: 'input', 'target') as HF datasets.
train_ds, val_ds = map(Dataset.from_list, (train_pairs, val_pairs))
train_ds, val_ds


## Tokenization


In [None]:
from transformers import AutoTokenizer

model_name = 'ai-forever/ruT5-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token limits for encoder prompts and decoder targets.
max_source_len = 384
max_target_len = 256


def preprocess(batch):
    """Tokenize a batch of prompts and their target annotations.

    Prompts are truncated to ``max_source_len`` tokens and targets to
    ``max_target_len``. The target ids are stored under ``labels``; padding is
    left to the data collator.
    """
    encoded = tokenizer(batch['input'], max_length=max_source_len, truncation=True)
    target_enc = tokenizer(
        text_target=batch['target'], max_length=max_target_len, truncation=True
    )
    encoded['labels'] = target_enc['input_ids']
    return encoded


train_tok, val_tok = (
    ds.map(preprocess, batched=True, remove_columns=ds.column_names)
    for ds in (train_ds, val_ds)
)
train_tok, val_tok


## Training


In [None]:
import evaluate
import numpy as np
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

import re

import torch

rouge = evaluate.load('rouge')

# Word tokenizer for ROUGE. The default rouge_score tokenizer keeps only
# [a-z0-9], which would erase Russian (Cyrillic) text entirely.
_ROUGE_WORD_RE = re.compile(r'\w+')


def _rouge_tokenize(text):
    """Lowercase ``text`` and split it into Unicode word tokens."""
    return _ROUGE_WORD_RE.findall(text.lower())


def compute_metrics(eval_pred):
    """Decode generated predictions and labels and score them with ROUGE.

    Args:
        eval_pred: ``(predictions, label_ids)`` from the Seq2SeqTrainer.
            Predictions may be generated token ids or, if 3-D, logits.

    Returns:
        The dict returned by the ``evaluate`` ROUGE metric.
    """
    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]
    if preds.ndim == 3:
        preds = np.argmax(preds, axis=-1)
    # The Trainer pads predictions from different batches with -100 when it
    # concatenates them; -100 is not a valid token id and breaks decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return rouge.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        tokenizer=_rouge_tokenize,
    )


model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# fp16=True raises on CPU-only machines, and T5-family models are prone to
# overflow (NaN loss) in fp16. Use bf16 where the GPU supports it, else fp32.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = Seq2SeqTrainingArguments(
    output_dir=str(Path('..') / 'models' / 'rut5_rag'),
    eval_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=6,
    warmup_ratio=0.1,
    weight_decay=0.01,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    # Without this, evaluation generation stops at the generation config's
    # max_length (20 unless the checkpoint overrides it), truncating outputs.
    generation_max_length=max_target_len,
    logging_steps=50,
    fp16=False,
    bf16=use_bf16,
    report_to='none',
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Training is off by default. Without it, the inference and evaluation cells
# run the pretrained (not fine-tuned) checkpoint.
DO_TRAIN = False
if DO_TRAIN:
    trainer.train()


## Inference


In [None]:
query = 'Я вижу город под подошвой'
ctx = build_context(query, top_k=3)
# Same prompt format as the training prompts (task prefix, CONTEXT, FRAGMENT).
# A different prompt at inference time would be a train/test mismatch.
source = (
    'Задача: объясни строку с учетом контекста.\n'
    f'CONTEXT:\n{ctx}\n\n'
    f'FRAGMENT:\n{query}\n\n'
    'ANNOTATION:'
)
model.eval()
# Same truncation limit as training; move inputs to the model's device, since
# after trainer.train() the model may be on a GPU.
inputs = tokenizer(
    source, return_tensors='pt', max_length=max_source_len, truncation=True
).to(model.device)
outputs = model.generate(**inputs, max_length=max_target_len)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


## Metrics and saving results


In [None]:
import json
from pathlib import Path
import torch
import numpy as np
from rouge_score import rouge_scorer
from sacrebleu import corpus_bleu
from tqdm import tqdm

import re


class _WordTokenizer:
    """Unicode-aware word tokenizer for ``rouge_score.RougeScorer``.

    rouge_score's default tokenizer keeps only ``[a-z0-9]`` characters, so
    Russian annotations would be scored on almost no tokens at all.
    """

    _word_re = re.compile(r'\w+')

    def tokenize(self, text):
        """Return the lowercase word tokens of ``text``."""
        return self._word_re.findall(text.lower())


def evaluate_rag_rut5(model, tokenizer, retriever, test_pairs, max_source_len=256, max_target_len=256, batch_size=8, top_k=3):
    """Generate annotations for ``test_pairs`` and score them with ROUGE and BLEU.

    Prompts are taken verbatim from ``pair['input']``, so evaluation uses
    exactly the prompt format and retrieved context that were built for the
    pairs (including the task prefix).

    Args:
        model: Seq2seq model used for generation; set to eval mode.
        tokenizer: Tokenizer matching ``model``.
        retriever: Unused; kept for call compatibility.
        test_pairs: Dicts with ``input`` (prompt) and ``target`` (reference).
        max_source_len: Truncation length for prompts, in tokens.
        max_target_len: ``max_length`` passed to ``model.generate``.
        batch_size: Number of prompts per generation batch.
        top_k: Unused; kept for call compatibility.

    Returns:
        ``(results, predictions, references)``. ``results`` holds the mean
        per-example ROUGE-1/2/L F-measures and sacrebleu's corpus BLEU score;
        all are 0.0 when ``test_pairs`` is empty.
    """
    model.eval()
    device = model.device

    scorer = rouge_scorer.RougeScorer(
        ['rouge1', 'rouge2', 'rougeL'], use_stemmer=False, tokenizer=_WordTokenizer()
    )

    sources = [pair['input'] for pair in test_pairs]
    references = [pair['target'] for pair in test_pairs]
    predictions = []

    print(f'Generating predictions for {len(test_pairs)} examples...')
    for start in tqdm(range(0, len(sources), batch_size)):
        inputs = tokenizer(
            sources[start:start + batch_size],
            max_length=max_source_len,
            truncation=True,
            padding=True,
            return_tensors='pt',
        ).to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_target_len,
                num_beams=4,
                early_stopping=True,
            )
        predictions.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))

    print('Computing ROUGE scores...')
    rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    for pred, ref in zip(predictions, references):
        scores = scorer.score(ref, pred)
        for name, values in rouge_scores.items():
            values.append(scores[name].fmeasure)

    print('Computing BLEU score...')
    # corpus_bleu expects a list of reference *streams*, each aligned with
    # `predictions`; one reference per example means a single stream.
    bleu_score = corpus_bleu(predictions, [references]).score if predictions else 0.0

    def _mean(values):
        return float(np.mean(values)) if values else 0.0

    results = {
        'method': 'ruT5 RAG (BM25)',
        'rouge1': _mean(rouge_scores['rouge1']),
        'rouge2': _mean(rouge_scores['rouge2']),
        'rougeL': _mean(rouge_scores['rougeL']),
        'bleu': bleu_score,
        'total_examples': len(test_pairs),
    }

    return results, predictions, references


# Evaluate with the same truncation limits used for training (tokenization cell).
results, preds, refs = evaluate_rag_rut5(
    model,
    tokenizer,
    retriever,
    val_pairs,
    max_source_len=max_source_len,
    max_target_len=max_target_len,
)

print(results)

out_path = Path('..') / 'data' / 'rut5_rag_results.json'
out_path.parent.mkdir(parents=True, exist_ok=True)
with open(out_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)
print('Saved:', out_path)


{'method': 'ruT5 + RAG (BM25)',
 'top1_accuracy': 0.08,
 'top3_accuracy': 0.15,
 'avg_similarity': 0.74,
 'rouge1': 0.279,
 'rouge2': 0.251,
 'rougeL': 0.271,
 'bleu': 0.22,
 'total_examples': 2000}