# SBERT retrieval

Self-contained implementation and evaluation. GPU recommended.


## Setup


In [1]:
from pathlib import Path
import json

DATASET_PATH = Path('..') / 'data' / 'annotations_dataset_full.json'
if not DATASET_PATH.exists():
    DATASET_PATH = Path('..') / 'data' / 'annotations_dataset_new.json'
print('Dataset:', DATASET_PATH)


Dataset: ../data/annotations_dataset_new.json


## Dataset stats


In [2]:
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    data = json.load(f)

songs = len(data)
annotations = sum(len(s.get('annotations', [])) for s in data)
print('Songs:', songs)
print('Annotations:', annotations)
print('Avg annotations per song:', round(annotations / songs, 2) if songs else 0)


Songs: 3291
Annotations: 22220
Avg annotations per song: 6.75


## Load pairs


In [3]:
fragments = []
annotations = []
metadata = []

for song in data:
    for ann in song.get('annotations', []):
        fragments.append(ann.get('fragment', ''))
        annotations.append(ann.get('annotation', ''))
        metadata.append({
            'artist': song.get('artist', ''),
            'title': song.get('title', ''),
            'votes': ann.get('votes', 0),
        })

print('Pairs:', len(fragments))

MAX_EXAMPLES = 2000  # set None for full run
if MAX_EXAMPLES:
    import random
    idx = list(range(len(fragments)))
    random.seed(42)
    random.shuffle(idx)
    idx = idx[:MAX_EXAMPLES]
    fragments = [fragments[i] for i in idx]
    annotations = [annotations[i] for i in idx]
    metadata = [metadata[i] for i in idx]
    print('Using subset:', len(fragments))


Pairs: 22220
Using subset: 2000


## Implementation


In [4]:
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from rouge_score import rouge_scorer
from sacrebleu import corpus_bleu

class SBERTRetriever:
    def __init__(self, fragments, annotations, metadata, model_name='paraphrase-multilingual-MiniLM-L12-v2'):
        self.fragments = fragments
        self.annotations = annotations
        self.metadata = metadata
        self.model = SentenceTransformer(model_name)
        self.embeddings = self.model.encode(self.fragments, show_progress_bar=True, convert_to_numpy=True)

    def find_similar(self, query, top_k=3):
        query_emb = self.model.encode([query], convert_to_numpy=True)
        sims = cosine_similarity(query_emb, self.embeddings)[0]
        top_indices = np.argsort(sims)[-top_k:][::-1]
        return [
            {
                'fragment': self.fragments[idx],
                'annotation': self.annotations[idx],
                'similarity': float(sims[idx]),
                'artist': self.metadata[idx]['artist'],
                'title': self.metadata[idx]['title'],
                'votes': self.metadata[idx]['votes'],
            }
            for idx in top_indices
        ]

def evaluate_sbert(fragments, annotations, metadata, model_name='paraphrase-multilingual-MiniLM-L12-v2'):
    retriever = SBERTRetriever(fragments, annotations, metadata, model_name=model_name)
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=False)

    correct_top1 = 0
    correct_top3 = 0
    predictions = []
    references = []
    similarities = []

    for i, fragment in enumerate(fragments):
        query_emb = retriever.embeddings[i:i+1]
        sims = cosine_similarity(query_emb, retriever.embeddings)[0]
        sims[i] = -1e9
        top_indices = np.argsort(sims)[-3:][::-1]

        predicted = annotations[top_indices[0]]
        true_annotation = annotations[i]

        predictions.append(predicted)
        references.append(true_annotation)
        similarities.append(sims[top_indices[0]])

        if predicted == true_annotation:
            correct_top1 += 1
        if true_annotation in [annotations[idx] for idx in top_indices]:
            correct_top3 += 1

    rouge_scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}
    for pred, ref in zip(predictions, references):
        scores = scorer.score(ref, pred)
        rouge_scores['rouge1'].append(scores['rouge1'].fmeasure)
        rouge_scores['rouge2'].append(scores['rouge2'].fmeasure)
        rouge_scores['rougeL'].append(scores['rougeL'].fmeasure)

    bleu = corpus_bleu(predictions, [[r] for r in references])

    return {
        'method': 'SBERT',
        'top1_accuracy': correct_top1 / len(fragments) if fragments else 0.0,
        'top3_accuracy': correct_top3 / len(fragments) if fragments else 0.0,
        'avg_similarity': float(np.mean(similarities)) if similarities else 0.0,
        'rouge1': float(np.mean(rouge_scores['rouge1'])) if rouge_scores['rouge1'] else 0.0,
        'rouge2': float(np.mean(rouge_scores['rouge2'])) if rouge_scores['rouge2'] else 0.0,
        'rougeL': float(np.mean(rouge_scores['rougeL'])) if rouge_scores['rougeL'] else 0.0,
        'bleu': bleu.score,
        'total_examples': len(fragments),
    }


  from .autonotebook import tqdm as notebook_tqdm


## Evaluate


In [5]:
sbert_results = evaluate_sbert(fragments, annotations, metadata)
sbert_results


Batches: 100%|██████████| 63/63 [00:02<00:00, 24.66it/s]


{'method': 'SBERT',
 'top1_accuracy': 0.008,
 'top3_accuracy': 0.008,
 'avg_similarity': 0.6854549646377563,
 'rouge1': 0.00870148379131968,
 'rouge2': 0.004899228911185434,
 'rougeL': 0.00870148379131968,
 'bleu': 100.00000000000004,
 'total_examples': 2000}

In [6]:
import json
from pathlib import Path

out_path = Path('..') / 'data' / 'sbert_results.json'
with open(out_path, 'w', encoding='utf-8') as f:
    json.dump(sbert_results, f, ensure_ascii=False, indent=2)
print('Saved:', out_path)

Saved: ../data/sbert_results.json


## Demo query


In [7]:
retriever = SBERTRetriever(fragments, annotations, metadata)
retriever.find_similar('Я вижу город под подошвой', top_k=2)


Batches: 100%|██████████| 63/63 [00:02<00:00, 30.74it/s]


[{'fragment': 'Город под подошвой\n Город под подошвой — этот город под подошвой',
  'annotation': '«Город под подошвой» — песня российского рэпера Оксимирона (Oxxxymiron).',
  'similarity': 0.8216304779052734,
  'artist': 'CMH',
  'title': 'GAZZ',
  'votes': 3},
 {'fragment': 'Еду в центр, это город дорог',
  'annotation': 'Игра слов отсылает нас к треку группы CENTR – Город дорог',
  'similarity': 0.6862443089485168,
  'artist': 'OG Buda',
  'title': 'Дзагоев (Dzagoev)',
  'votes': 3}]